From 2c1e35981f7064f293669109097eb4b8c4942692 Mon Sep 17 00:00:00 2001 From: "gcf-owl-bot[bot]" <78513119+gcf-owl-bot[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 20:56:00 +0000 Subject: [PATCH] feat: [google-ai-generativelanguage] Add GoogleSearch tool type (#13411) - [ ] Regenerate this pull request now. BEGIN_COMMIT_OVERRIDE feat: Add GoogleSearch tool type feat: Add return type `Schema response` to function declarations feat: Add id to FunctionCall and FunctionResponse feat: Add civic_integrity toggle to generation_config feat: Add response_modalities to generation_config feat: Add voice_config to generation_config docs: Update safety filter list to include civic_integrity feat: Add image_safety block_reason + finish_reason feat: Add v1alpha feat: Add BidiGenerateContent + all the necessary protos feat: Add TuningMultiturnExample END_COMMIT_OVERRIDE feat: Add return type `Schema response` to function declarations feat: Add id to FunctionCall and FunctionResponse feat: Add civic_integrity toggle to generation_config feat: Add response_modalities to generation_config feat: Add voice_config to generation_config docs: Update safety filter list to include civic_integrity feat: Add image_safety block_reason + finish_reason PiperOrigin-RevId: 713080033 Source-Link: https://github.com/googleapis/googleapis/commit/130b113520c1b3fffd90f67198681b5fe84de464 Source-Link: https://github.com/googleapis/googleapis-gen/commit/d3455bab27a5c7f8c39e53482649468baa021ce0 Copy-Tag: eyJwIjoicGFja2FnZXMvZ29vZ2xlLWFpLWdlbmVyYXRpdmVsYW5ndWFnZS8uT3dsQm90LnlhbWwiLCJoIjoiZDM0NTViYWIyN2E1YzdmOGMzOWU1MzQ4MjY0OTQ2OGJhYTAyMWNlMCJ9 BEGIN_NESTED_COMMIT feat: [google-ai-generativelanguage] Add civic_integrity toggle docs: Update safety filter list to include civic_integrity feat: Add image_safety block_reason + finish_reason PiperOrigin-RevId: 713079488 Source-Link: https://github.com/googleapis/googleapis/commit/23553a97fb9afa8b39b66c79a51227c7abb2af95 Source-Link: https://github.com/googleapis/googleapis-gen/commit/28d644442dc9bb3d0d3d6f1ee5f3a022f8ff0c94 Copy-Tag: eyJwIjoicGFja2FnZXMvZ29vZ2xlLWFpLWdlbmVyYXRpdmVsYW5ndWFnZS8uT3dsQm90LnlhbWwiLCJoIjoiMjhkNjQ0NDQyZGM5YmIzZDBkM2Q2ZjFlZTVmM2EwMjJmOGZmMGM5NCJ9 END_NESTED_COMMIT BEGIN_NESTED_COMMIT feat: [google-ai-generativelanguage] Add v1alpha feat: Add BidiGenerateContent + all the necessary protos feat: Add TuningMultiturnExample PiperOrigin-RevId: 713078684 Source-Link: https://github.com/googleapis/googleapis/commit/ae0e102e4744749d019f37209c85b4e892dbc3a4 Source-Link: https://github.com/googleapis/googleapis-gen/commit/970d7ef8da844744f1de8d527a94ebd25dc75384 Copy-Tag: eyJwIjoicGFja2FnZXMvZ29vZ2xlLWFpLWdlbmVyYXRpdmVsYW5ndWFnZS8uT3dsQm90LnlhbWwiLCJoIjoiOTcwZDdlZjhkYTg0NDc0NGYxZGU4ZDUyN2E5NGViZDI1ZGM3NTM4NCJ9 END_NESTED_COMMIT --------- Co-authored-by: Owl Bot Co-authored-by: Victor Chudnovsky --- .../cache_service.rst | 10 + .../discuss_service.rst | 6 + .../file_service.rst | 10 + .../generative_service.rst | 6 + .../model_service.rst | 10 + .../permission_service.rst | 10 + .../prediction_service.rst | 6 + .../retriever_service.rst | 10 + .../generativelanguage_v1alpha/services_.rst | 14 + .../text_service.rst | 6 + .../generativelanguage_v1alpha/types_.rst | 6 + .../docs/index.rst | 8 + .../google/ai/generativelanguage/__init__.py | 6 + .../ai/generativelanguage/gapic_version.py | 2 +- .../ai/generativelanguage_v1/gapic_version.py | 2 +- .../generative_service/async_client.py | 4 +- .../services/generative_service/client.py | 4 +- .../types/generative_service.py | 24 +- .../ai/generativelanguage_v1alpha/__init__.py | 413 + .../gapic_metadata.json | 1020 + .../gapic_version.py | 16 + .../ai/generativelanguage_v1alpha/py.typed | 2 + .../services/__init__.py | 15 + .../services/cache_service/__init__.py | 22 + .../services/cache_service/async_client.py | 950 + .../services/cache_service/client.py | 1340 ++ .../services/cache_service/pagers.py | 197 + .../cache_service/transports/README.rst | 9 + .../cache_service/transports/__init__.py | 36 + .../services/cache_service/transports/base.py | 263 + .../services/cache_service/transports/grpc.py | 517 + .../cache_service/transports/grpc_asyncio.py | 573 + .../services/cache_service/transports/rest.py | 1428 ++ .../cache_service/transports/rest_base.py | 399 + .../services/discuss_service/__init__.py | 22 + .../services/discuss_service/async_client.py | 740 + .../services/discuss_service/client.py | 1128 ++ .../discuss_service/transports/README.rst | 9 + .../discuss_service/transports/__init__.py | 36 + .../discuss_service/transports/base.py | 213 + .../discuss_service/transports/grpc.py | 431 + .../transports/grpc_asyncio.py | 468 + .../discuss_service/transports/rest.py | 909 + .../discuss_service/transports/rest_base.py | 268 + .../services/file_service/__init__.py | 22 + .../services/file_service/async_client.py | 780 + .../services/file_service/client.py | 1165 ++ .../services/file_service/pagers.py | 197 + .../file_service/transports/README.rst | 9 + .../file_service/transports/__init__.py | 36 + .../services/file_service/transports/base.py | 239 + .../services/file_service/transports/grpc.py | 472 + .../file_service/transports/grpc_asyncio.py | 523 + .../services/file_service/transports/rest.py | 1191 ++ .../file_service/transports/rest_base.py | 323 + .../services/generative_service/__init__.py | 22 + .../generative_service/async_client.py | 1366 ++ .../services/generative_service/client.py | 1751 ++ .../generative_service/transports/README.rst | 9 + .../generative_service/transports/__init__.py | 36 + .../generative_service/transports/base.py | 298 + .../generative_service/transports/grpc.py | 591 + .../transports/grpc_asyncio.py | 654 + .../generative_service/transports/rest.py | 1726 ++ .../transports/rest_base.py | 510 + .../services/model_service/__init__.py | 22 + .../services/model_service/async_client.py | 1261 ++ .../services/model_service/client.py | 1646 ++ .../services/model_service/pagers.py | 353 + .../model_service/transports/README.rst | 9 + .../model_service/transports/__init__.py | 36 + .../services/model_service/transports/base.py | 290 + .../services/model_service/transports/grpc.py | 582 + .../model_service/transports/grpc_asyncio.py | 655 + .../services/model_service/transports/rest.py | 1819 ++ .../model_service/transports/rest_base.py | 476 + .../services/permission_service/__init__.py | 22 + .../permission_service/async_client.py | 1150 ++ .../services/permission_service/client.py | 1532 ++ .../services/permission_service/pagers.py | 197 + .../permission_service/transports/README.rst | 9 + .../permission_service/transports/__init__.py | 36 + .../permission_service/transports/base.py | 272 + .../permission_service/transports/grpc.py | 541 + .../transports/grpc_asyncio.py | 604 + .../permission_service/transports/rest.py | 1671 ++ .../transports/rest_base.py | 493 + .../services/prediction_service/__init__.py | 22 + .../prediction_service/async_client.py | 535 + .../services/prediction_service/client.py | 929 + .../prediction_service/transports/README.rst | 9 + .../prediction_service/transports/__init__.py | 36 + .../prediction_service/transports/base.py | 196 + .../prediction_service/transports/grpc.py | 396 + .../transports/grpc_asyncio.py | 429 + .../prediction_service/transports/rest.py | 700 + .../transports/rest_base.py | 211 + .../services/retriever_service/__init__.py | 22 + .../retriever_service/async_client.py | 2499 +++ .../services/retriever_service/client.py | 2875 +++ .../services/retriever_service/pagers.py | 509 + .../retriever_service/transports/README.rst | 9 + .../retriever_service/transports/__init__.py | 36 + .../retriever_service/transports/base.py | 481 + .../retriever_service/transports/grpc.py | 908 + .../transports/grpc_asyncio.py | 1048 + .../retriever_service/transports/rest.py | 4067 ++++ .../retriever_service/transports/rest_base.py | 1196 ++ .../services/text_service/__init__.py | 22 + .../services/text_service/async_client.py | 1005 + .../services/text_service/client.py | 1386 ++ .../text_service/transports/README.rst | 9 + .../text_service/transports/__init__.py | 36 + .../services/text_service/transports/base.py | 246 + .../services/text_service/transports/grpc.py | 485 + .../text_service/transports/grpc_asyncio.py | 536 + .../services/text_service/transports/rest.py | 1293 ++ .../text_service/transports/rest_base.py | 387 + .../types/__init__.py | 369 + .../types/cache_service.py | 167 + .../types/cached_content.py | 182 + .../types/citation.py | 101 + .../types/content.py | 819 + .../types/discuss_service.py | 356 + .../generativelanguage_v1alpha/types/file.py | 174 + .../types/file_service.py | 145 + .../types/generative_service.py | 2139 ++ .../generativelanguage_v1alpha/types/model.py | 171 + .../types/model_service.py | 332 + .../types/permission.py | 141 + .../types/permission_service.py | 220 + .../types/prediction_service.py | 79 + .../types/retriever.py | 411 + .../types/retriever_service.py | 793 + .../types/safety.py | 276 + .../types/text_service.py | 441 + .../types/tuned_model.py | 542 + .../ai/generativelanguage_v1beta/__init__.py | 6 + .../gapic_version.py | 2 +- .../services/file_service/async_client.py | 2 + .../services/file_service/client.py | 2 + .../services/file_service/transports/rest.py | 2 + .../generative_service/async_client.py | 4 +- .../services/generative_service/client.py | 4 +- .../types/__init__.py | 6 + .../types/content.py | 45 + .../generativelanguage_v1beta/types/file.py | 2 + .../types/generative_service.py | 129 +- .../gapic_version.py | 2 +- .../gapic_version.py | 2 +- ...che_service_create_cached_content_async.py | 51 + ...ache_service_create_cached_content_sync.py | 51 + ...che_service_delete_cached_content_async.py | 50 + ...ache_service_delete_cached_content_sync.py | 50 + ..._cache_service_get_cached_content_async.py | 52 + ...d_cache_service_get_cached_content_sync.py | 52 + ...ache_service_list_cached_contents_async.py | 52 + ...cache_service_list_cached_contents_sync.py | 52 + ...che_service_update_cached_content_async.py | 51 + ...ache_service_update_cached_content_sync.py | 51 + ...cuss_service_count_message_tokens_async.py | 56 + ...scuss_service_count_message_tokens_sync.py | 56 + ..._discuss_service_generate_message_async.py | 56 + ...d_discuss_service_generate_message_sync.py | 56 + ...enerated_file_service_create_file_async.py | 51 + ...generated_file_service_create_file_sync.py | 51 + ...enerated_file_service_delete_file_async.py | 50 + ...generated_file_service_delete_file_sync.py | 50 + ...a_generated_file_service_get_file_async.py | 52 + ...ha_generated_file_service_get_file_sync.py | 52 + ...generated_file_service_list_files_async.py | 52 + ..._generated_file_service_list_files_sync.py | 52 + ...tive_service_batch_embed_contents_async.py | 56 + ...ative_service_batch_embed_contents_sync.py | 56 + ...ive_service_bidi_generate_content_async.py | 66 + ...tive_service_bidi_generate_content_sync.py | 66 + ...d_generative_service_count_tokens_async.py | 52 + ...ed_generative_service_count_tokens_sync.py | 52 + ..._generative_service_embed_content_async.py | 52 + ...d_generative_service_embed_content_sync.py | 52 + ...enerative_service_generate_answer_async.py | 53 + ...generative_service_generate_answer_sync.py | 53 + ...nerative_service_generate_content_async.py | 52 + ...enerative_service_generate_content_sync.py | 52 + ...e_service_stream_generate_content_async.py | 53 + ...ve_service_stream_generate_content_sync.py | 53 + ..._model_service_create_tuned_model_async.py | 55 + ...d_model_service_create_tuned_model_sync.py | 55 + ..._model_service_delete_tuned_model_async.py | 50 + ...d_model_service_delete_tuned_model_sync.py | 50 + ...generated_model_service_get_model_async.py | 52 + ..._generated_model_service_get_model_sync.py | 52 + ...ted_model_service_get_tuned_model_async.py | 52 + ...ated_model_service_get_tuned_model_sync.py | 52 + ...nerated_model_service_list_models_async.py | 52 + ...enerated_model_service_list_models_sync.py | 52 + ...d_model_service_list_tuned_models_async.py | 52 + ...ed_model_service_list_tuned_models_sync.py | 52 + ..._model_service_update_tuned_model_async.py | 51 + ...d_model_service_update_tuned_model_sync.py | 51 + ...mission_service_create_permission_async.py | 52 + ...rmission_service_create_permission_sync.py | 52 + ...mission_service_delete_permission_async.py | 50 + ...rmission_service_delete_permission_sync.py | 50 + ...permission_service_get_permission_async.py | 52 + ..._permission_service_get_permission_sync.py | 52 + ...rmission_service_list_permissions_async.py | 53 + ...ermission_service_list_permissions_sync.py | 53 + ...ission_service_transfer_ownership_async.py | 53 + ...mission_service_transfer_ownership_sync.py | 53 + ...mission_service_update_permission_async.py | 51 + ...rmission_service_update_permission_sync.py | 51 + ...erated_prediction_service_predict_async.py | 56 + ...nerated_prediction_service_predict_sync.py | 56 + ...iever_service_batch_create_chunks_async.py | 56 + ...riever_service_batch_create_chunks_sync.py | 56 + ...iever_service_batch_delete_chunks_async.py | 53 + ...riever_service_batch_delete_chunks_sync.py | 53 + ...iever_service_batch_update_chunks_async.py | 55 + ...riever_service_batch_update_chunks_sync.py | 55 + ...ed_retriever_service_create_chunk_async.py | 56 + ...ted_retriever_service_create_chunk_sync.py | 56 + ...d_retriever_service_create_corpus_async.py | 51 + ...ed_retriever_service_create_corpus_sync.py | 51 + ...retriever_service_create_document_async.py | 52 + ..._retriever_service_create_document_sync.py | 52 + ...ed_retriever_service_delete_chunk_async.py | 50 + ...ted_retriever_service_delete_chunk_sync.py | 50 + ...d_retriever_service_delete_corpus_async.py | 50 + ...ed_retriever_service_delete_corpus_sync.py | 50 + ...retriever_service_delete_document_async.py | 50 + ..._retriever_service_delete_document_sync.py | 50 + ...rated_retriever_service_get_chunk_async.py | 52 + ...erated_retriever_service_get_chunk_sync.py | 52 + ...ated_retriever_service_get_corpus_async.py | 52 + ...rated_retriever_service_get_corpus_sync.py | 52 + ...ed_retriever_service_get_document_async.py | 52 + ...ted_retriever_service_get_document_sync.py | 52 + ...ted_retriever_service_list_chunks_async.py | 53 + ...ated_retriever_service_list_chunks_sync.py | 53 + ...ed_retriever_service_list_corpora_async.py | 52 + ...ted_retriever_service_list_corpora_sync.py | 52 + ..._retriever_service_list_documents_async.py | 53 + ...d_retriever_service_list_documents_sync.py | 53 + ...ed_retriever_service_query_corpus_async.py | 53 + ...ted_retriever_service_query_corpus_sync.py | 53 + ..._retriever_service_query_document_async.py | 53 + ...d_retriever_service_query_document_sync.py | 53 + ...ed_retriever_service_update_chunk_async.py | 55 + ...ted_retriever_service_update_chunk_sync.py | 55 + ...d_retriever_service_update_corpus_async.py | 51 + ...ed_retriever_service_update_corpus_sync.py | 51 + ...retriever_service_update_document_async.py | 51 + ..._retriever_service_update_document_sync.py | 51 + ...ted_text_service_batch_embed_text_async.py | 52 + ...ated_text_service_batch_embed_text_sync.py | 52 + ...ed_text_service_count_text_tokens_async.py | 56 + ...ted_text_service_count_text_tokens_sync.py | 56 + ...generated_text_service_embed_text_async.py | 52 + ..._generated_text_service_embed_text_sync.py | 52 + ...erated_text_service_generate_text_async.py | 56 + ...nerated_text_service_generate_text_sync.py | 56 + ...adata_google.ai.generativelanguage.v1.json | 2 +- ..._google.ai.generativelanguage.v1alpha.json | 9183 +++++++++ ...a_google.ai.generativelanguage.v1beta.json | 2 +- ..._google.ai.generativelanguage.v1beta2.json | 2 +- ..._google.ai.generativelanguage.v1beta3.json | 2 +- ...xup_generativelanguage_v1alpha_keywords.py | 231 + .../generativelanguage_v1alpha/__init__.py | 15 + .../test_cache_service.py | 6111 ++++++ .../test_discuss_service.py | 3756 ++++ .../test_file_service.py | 4644 +++++ .../test_generative_service.py | 6939 +++++++ .../test_model_service.py | 7901 ++++++++ .../test_permission_service.py | 6934 +++++++ .../test_prediction_service.py | 3001 +++ .../test_retriever_service.py | 16486 ++++++++++++++++ .../test_text_service.py | 5101 +++++ .../test_cache_service.py | 28 +- 279 files changed, 143967 insertions(+), 27 deletions(-) create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/cache_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/discuss_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/file_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/generative_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/model_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/permission_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/prediction_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/retriever_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/services_.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/text_service.rst create mode 100644 packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/types_.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_metadata.json create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_version.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/py.typed create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/pagers.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/pagers.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/pagers.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/pagers.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/pagers.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/async_client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/client.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/README.rst create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc_asyncio.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest_base.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/__init__.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cache_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cached_content.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/citation.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/content.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/discuss_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/generative_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/prediction_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/safety.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/text_service.py create mode 100644 packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/tuned_model.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_async.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_sync.py create mode 100644 packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1alpha.json create mode 100644 packages/google-ai-generativelanguage/scripts/fixup_generativelanguage_v1alpha_keywords.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/__init__.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_cache_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_discuss_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_file_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_generative_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_model_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_permission_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_prediction_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_retriever_service.py create mode 100644 packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_text_service.py diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/cache_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/cache_service.rst new file mode 100644 index 000000000000..6d1b22455271 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/cache_service.rst @@ -0,0 +1,10 @@ +CacheService +------------------------------ + +.. automodule:: google.ai.generativelanguage_v1alpha.services.cache_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1alpha.services.cache_service.pagers + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/discuss_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/discuss_service.rst new file mode 100644 index 000000000000..4996aa39dbf3 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/discuss_service.rst @@ -0,0 +1,6 @@ +DiscussService +-------------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.discuss_service + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/file_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/file_service.rst new file mode 100644 index 000000000000..fc5f5538643c --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/file_service.rst @@ -0,0 +1,10 @@ +FileService +----------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.file_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1alpha.services.file_service.pagers + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/generative_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/generative_service.rst new file mode 100644 index 000000000000..6c27b40a2a46 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/generative_service.rst @@ -0,0 +1,6 @@ +GenerativeService +----------------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.generative_service + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/model_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/model_service.rst new file mode 100644 index 000000000000..7f119e39a597 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/model_service.rst @@ -0,0 +1,10 @@ +ModelService +------------------------------ + +.. automodule:: google.ai.generativelanguage_v1alpha.services.model_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1alpha.services.model_service.pagers + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/permission_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/permission_service.rst new file mode 100644 index 000000000000..0873faf0c999 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/permission_service.rst @@ -0,0 +1,10 @@ +PermissionService +----------------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.permission_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1alpha.services.permission_service.pagers + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/prediction_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/prediction_service.rst new file mode 100644 index 000000000000..6a93c8cc2eb7 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/prediction_service.rst @@ -0,0 +1,6 @@ +PredictionService +----------------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.prediction_service + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/retriever_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/retriever_service.rst new file mode 100644 index 000000000000..5d9b9f330225 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/retriever_service.rst @@ -0,0 +1,10 @@ +RetrieverService +---------------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.retriever_service + :members: + :inherited-members: + +.. automodule:: google.ai.generativelanguage_v1alpha.services.retriever_service.pagers + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/services_.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/services_.rst new file mode 100644 index 000000000000..17ca3d964c15 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/services_.rst @@ -0,0 +1,14 @@ +Services for Google Ai Generativelanguage v1alpha API +===================================================== +.. toctree:: + :maxdepth: 2 + + cache_service + discuss_service + file_service + generative_service + model_service + permission_service + prediction_service + retriever_service + text_service diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/text_service.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/text_service.rst new file mode 100644 index 000000000000..2c27b0cc60b7 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/text_service.rst @@ -0,0 +1,6 @@ +TextService +----------------------------- + +.. automodule:: google.ai.generativelanguage_v1alpha.services.text_service + :members: + :inherited-members: diff --git a/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/types_.rst b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/types_.rst new file mode 100644 index 000000000000..ea7512835fb8 --- /dev/null +++ b/packages/google-ai-generativelanguage/docs/generativelanguage_v1alpha/types_.rst @@ -0,0 +1,6 @@ +Types for Google Ai Generativelanguage v1alpha API +================================================== + +.. automodule:: google.ai.generativelanguage_v1alpha.types + :members: + :show-inheritance: diff --git a/packages/google-ai-generativelanguage/docs/index.rst b/packages/google-ai-generativelanguage/docs/index.rst index 5688bf71543b..b973b49f5784 100644 --- a/packages/google-ai-generativelanguage/docs/index.rst +++ b/packages/google-ai-generativelanguage/docs/index.rst @@ -22,6 +22,14 @@ API Reference generativelanguage_v1/services_ generativelanguage_v1/types_ +API Reference +------------- +.. toctree:: + :maxdepth: 2 + + generativelanguage_v1alpha/services_ + generativelanguage_v1alpha/types_ + API Reference ------------- .. toctree:: diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage/__init__.py index 750b54051c3f..fb8789d5bab9 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage/__init__.py @@ -144,11 +144,14 @@ GroundingMetadata, GroundingSupport, LogprobsResult, + PrebuiltVoiceConfig, RetrievalMetadata, SearchEntryPoint, Segment, SemanticRetrieverConfig, + SpeechConfig, TaskType, + VoiceConfig, ) from google.ai.generativelanguage_v1beta.types.model import Model from google.ai.generativelanguage_v1beta.types.model_service import ( @@ -330,10 +333,13 @@ "GroundingMetadata", "GroundingSupport", "LogprobsResult", + "PrebuiltVoiceConfig", "RetrievalMetadata", "SearchEntryPoint", "Segment", "SemanticRetrieverConfig", + "SpeechConfig", + "VoiceConfig", "TaskType", "Model", "CreateTunedModelMetadata", diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage/gapic_version.py index 0b6dbde2b051..558c8aab67c5 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage/gapic_version.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.6.14" # {x-release-please-version} +__version__ = "0.0.0" # {x-release-please-version} diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/gapic_version.py index 0b6dbde2b051..558c8aab67c5 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/gapic_version.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.6.14" # {x-release-please-version} +__version__ = "0.0.0" # {x-release-please-version} diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/async_client.py index 2a0471ed2dc9..88a9c247893c 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/async_client.py @@ -353,7 +353,7 @@ async def sample_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this @@ -494,7 +494,7 @@ async def sample_stream_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/client.py index 258a77134a14..5ad720e8cf2e 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/services/generative_service/client.py @@ -740,7 +740,7 @@ def sample_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this @@ -878,7 +878,7 @@ def sample_stream_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/types/generative_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/types/generative_service.py index f52aa49108f1..bcc8bd6f3ed6 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/types/generative_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1/types/generative_service.py @@ -98,7 +98,7 @@ class GenerateContentRequest(proto.Message): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. contents (MutableSequence[google.ai.generativelanguage_v1.types.Content]): Required. The content of the current conversation with the model. @@ -124,8 +124,8 @@ class GenerateContentRequest(proto.Message): will use the default safety setting for that category. Harm categories HARM_CATEGORY_HATE_SPEECH, HARM_CATEGORY_SEXUALLY_EXPLICIT, - HARM_CATEGORY_DANGEROUS_CONTENT, HARM_CATEGORY_HARASSMENT - are supported. Refer to the + HARM_CATEGORY_DANGEROUS_CONTENT, HARM_CATEGORY_HARASSMENT, + HARM_CATEGORY_CIVIC_INTEGRITY are supported. Refer to the `guide `__ for detailed information on available safety settings. Also refer to the `Safety @@ -287,6 +287,11 @@ class GenerationConfig(proto.Message): [Candidate.logprobs_result][google.ai.generativelanguage.v1.Candidate.logprobs_result]. This field is a member of `oneof`_ ``_logprobs``. + enable_enhanced_civic_answers (bool): + Optional. Enables enhanced civic answers. It + may not be available for all models. + + This field is a member of `oneof`_ ``_enable_enhanced_civic_answers``. """ candidate_count: int = proto.Field( @@ -338,6 +343,11 @@ class GenerationConfig(proto.Message): number=18, optional=True, ) + enable_enhanced_civic_answers: bool = proto.Field( + proto.BOOL, + number=19, + optional=True, + ) class GenerateContentResponse(proto.Message): @@ -397,12 +407,16 @@ class BlockReason(proto.Enum): included from the terminology blocklist. PROHIBITED_CONTENT (4): Prompt was blocked due to prohibited content. + IMAGE_SAFETY (5): + Candidates blocked due to unsafe image + generation content. """ BLOCK_REASON_UNSPECIFIED = 0 SAFETY = 1 OTHER = 2 BLOCKLIST = 3 PROHIBITED_CONTENT = 4 + IMAGE_SAFETY = 5 block_reason: "GenerateContentResponse.PromptFeedback.BlockReason" = ( proto.Field( @@ -548,6 +562,9 @@ class FinishReason(proto.Enum): MALFORMED_FUNCTION_CALL (10): The function call generated by the model is invalid. + IMAGE_SAFETY (11): + Token generation stopped because generated + images contain safety violations. """ FINISH_REASON_UNSPECIFIED = 0 STOP = 1 @@ -560,6 +577,7 @@ class FinishReason(proto.Enum): PROHIBITED_CONTENT = 8 SPII = 9 MALFORMED_FUNCTION_CALL = 10 + IMAGE_SAFETY = 11 index: int = proto.Field( proto.INT32, diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/__init__.py new file mode 100644 index 000000000000..33d2e7c26a31 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/__init__.py @@ -0,0 +1,413 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +__version__ = package_version.__version__ + + +from .services.cache_service import CacheServiceAsyncClient, CacheServiceClient +from .services.discuss_service import DiscussServiceAsyncClient, DiscussServiceClient +from .services.file_service import FileServiceAsyncClient, FileServiceClient +from .services.generative_service import ( + GenerativeServiceAsyncClient, + GenerativeServiceClient, +) +from .services.model_service import ModelServiceAsyncClient, ModelServiceClient +from .services.permission_service import ( + PermissionServiceAsyncClient, + PermissionServiceClient, +) +from .services.prediction_service import ( + PredictionServiceAsyncClient, + PredictionServiceClient, +) +from .services.retriever_service import ( + RetrieverServiceAsyncClient, + RetrieverServiceClient, +) +from .services.text_service import TextServiceAsyncClient, TextServiceClient +from .types.cache_service import ( + CreateCachedContentRequest, + DeleteCachedContentRequest, + GetCachedContentRequest, + ListCachedContentsRequest, + ListCachedContentsResponse, + UpdateCachedContentRequest, +) +from .types.cached_content import CachedContent +from .types.citation import CitationMetadata, CitationSource +from .types.content import ( + Blob, + CodeExecution, + CodeExecutionResult, + Content, + DynamicRetrievalConfig, + ExecutableCode, + FileData, + FunctionCall, + FunctionCallingConfig, + FunctionDeclaration, + FunctionResponse, + GoogleSearchRetrieval, + GroundingPassage, + GroundingPassages, + Part, + Schema, + Tool, + ToolConfig, + Type, +) +from .types.discuss_service import ( + CountMessageTokensRequest, + CountMessageTokensResponse, + Example, + GenerateMessageRequest, + GenerateMessageResponse, + Message, + MessagePrompt, +) +from .types.file import File, VideoMetadata +from .types.file_service import ( + CreateFileRequest, + CreateFileResponse, + DeleteFileRequest, + GetFileRequest, + ListFilesRequest, + ListFilesResponse, +) +from .types.generative_service import ( + AttributionSourceId, + BatchEmbedContentsRequest, + BatchEmbedContentsResponse, + BidiGenerateContentClientContent, + BidiGenerateContentClientMessage, + BidiGenerateContentRealtimeInput, + BidiGenerateContentServerContent, + BidiGenerateContentServerMessage, + BidiGenerateContentSetup, + BidiGenerateContentSetupComplete, + BidiGenerateContentToolCall, + BidiGenerateContentToolCallCancellation, + BidiGenerateContentToolResponse, + Candidate, + ContentEmbedding, + CountTokensRequest, + CountTokensResponse, + EmbedContentRequest, + EmbedContentResponse, + GenerateAnswerRequest, + GenerateAnswerResponse, + GenerateContentRequest, + GenerateContentResponse, + GenerationConfig, + GroundingAttribution, + GroundingChunk, + GroundingMetadata, + GroundingSupport, + LogprobsResult, + PrebuiltVoiceConfig, + RetrievalMetadata, + SearchEntryPoint, + Segment, + SemanticRetrieverConfig, + SpeechConfig, + TaskType, + VoiceConfig, +) +from .types.model import Model +from .types.model_service import ( + CreateTunedModelMetadata, + CreateTunedModelRequest, + DeleteTunedModelRequest, + GetModelRequest, + GetTunedModelRequest, + ListModelsRequest, + ListModelsResponse, + ListTunedModelsRequest, + ListTunedModelsResponse, + UpdateTunedModelRequest, +) +from .types.permission import Permission +from .types.permission_service import ( + CreatePermissionRequest, + DeletePermissionRequest, + GetPermissionRequest, + ListPermissionsRequest, + ListPermissionsResponse, + TransferOwnershipRequest, + TransferOwnershipResponse, + UpdatePermissionRequest, +) +from .types.prediction_service import PredictRequest, PredictResponse +from .types.retriever import ( + Chunk, + ChunkData, + Condition, + Corpus, + CustomMetadata, + Document, + MetadataFilter, + StringList, +) +from .types.retriever_service import ( + BatchCreateChunksRequest, + BatchCreateChunksResponse, + BatchDeleteChunksRequest, + BatchUpdateChunksRequest, + BatchUpdateChunksResponse, + CreateChunkRequest, + CreateCorpusRequest, + CreateDocumentRequest, + DeleteChunkRequest, + DeleteCorpusRequest, + DeleteDocumentRequest, + GetChunkRequest, + GetCorpusRequest, + GetDocumentRequest, + ListChunksRequest, + ListChunksResponse, + ListCorporaRequest, + ListCorporaResponse, + ListDocumentsRequest, + ListDocumentsResponse, + QueryCorpusRequest, + QueryCorpusResponse, + QueryDocumentRequest, + QueryDocumentResponse, + RelevantChunk, + UpdateChunkRequest, + UpdateCorpusRequest, + UpdateDocumentRequest, +) +from .types.safety import ( + ContentFilter, + HarmCategory, + SafetyFeedback, + SafetyRating, + SafetySetting, +) +from .types.text_service import ( + BatchEmbedTextRequest, + BatchEmbedTextResponse, + CountTextTokensRequest, + CountTextTokensResponse, + Embedding, + EmbedTextRequest, + EmbedTextResponse, + GenerateTextRequest, + GenerateTextResponse, + TextCompletion, + TextPrompt, +) +from .types.tuned_model import ( + Dataset, + Hyperparameters, + TunedModel, + TunedModelSource, + TuningContent, + TuningExample, + TuningExamples, + TuningMultiturnExample, + TuningPart, + TuningSnapshot, + TuningTask, +) + +__all__ = ( + "CacheServiceAsyncClient", + "DiscussServiceAsyncClient", + "FileServiceAsyncClient", + "GenerativeServiceAsyncClient", + "ModelServiceAsyncClient", + "PermissionServiceAsyncClient", + "PredictionServiceAsyncClient", + "RetrieverServiceAsyncClient", + "TextServiceAsyncClient", + "AttributionSourceId", + "BatchCreateChunksRequest", + "BatchCreateChunksResponse", + "BatchDeleteChunksRequest", + "BatchEmbedContentsRequest", + "BatchEmbedContentsResponse", + "BatchEmbedTextRequest", + "BatchEmbedTextResponse", + "BatchUpdateChunksRequest", + "BatchUpdateChunksResponse", + "BidiGenerateContentClientContent", + "BidiGenerateContentClientMessage", + "BidiGenerateContentRealtimeInput", + "BidiGenerateContentServerContent", + "BidiGenerateContentServerMessage", + "BidiGenerateContentSetup", + "BidiGenerateContentSetupComplete", + "BidiGenerateContentToolCall", + "BidiGenerateContentToolCallCancellation", + "BidiGenerateContentToolResponse", + "Blob", + "CacheServiceClient", + "CachedContent", + "Candidate", + "Chunk", + "ChunkData", + "CitationMetadata", + "CitationSource", + "CodeExecution", + "CodeExecutionResult", + "Condition", + "Content", + "ContentEmbedding", + "ContentFilter", + "Corpus", + "CountMessageTokensRequest", + "CountMessageTokensResponse", + "CountTextTokensRequest", + "CountTextTokensResponse", + "CountTokensRequest", + "CountTokensResponse", + "CreateCachedContentRequest", + "CreateChunkRequest", + "CreateCorpusRequest", + "CreateDocumentRequest", + "CreateFileRequest", + "CreateFileResponse", + "CreatePermissionRequest", + "CreateTunedModelMetadata", + "CreateTunedModelRequest", + "CustomMetadata", + "Dataset", + "DeleteCachedContentRequest", + "DeleteChunkRequest", + "DeleteCorpusRequest", + "DeleteDocumentRequest", + "DeleteFileRequest", + "DeletePermissionRequest", + "DeleteTunedModelRequest", + "DiscussServiceClient", + "Document", + "DynamicRetrievalConfig", + "EmbedContentRequest", + "EmbedContentResponse", + "EmbedTextRequest", + "EmbedTextResponse", + "Embedding", + "Example", + "ExecutableCode", + "File", + "FileData", + "FileServiceClient", + "FunctionCall", + "FunctionCallingConfig", + "FunctionDeclaration", + "FunctionResponse", + "GenerateAnswerRequest", + "GenerateAnswerResponse", + "GenerateContentRequest", + "GenerateContentResponse", + "GenerateMessageRequest", + "GenerateMessageResponse", + "GenerateTextRequest", + "GenerateTextResponse", + "GenerationConfig", + "GenerativeServiceClient", + "GetCachedContentRequest", + "GetChunkRequest", + "GetCorpusRequest", + "GetDocumentRequest", + "GetFileRequest", + "GetModelRequest", + "GetPermissionRequest", + "GetTunedModelRequest", + "GoogleSearchRetrieval", + "GroundingAttribution", + "GroundingChunk", + "GroundingMetadata", + "GroundingPassage", + "GroundingPassages", + "GroundingSupport", + "HarmCategory", + "Hyperparameters", + "ListCachedContentsRequest", + "ListCachedContentsResponse", + "ListChunksRequest", + "ListChunksResponse", + "ListCorporaRequest", + "ListCorporaResponse", + "ListDocumentsRequest", + "ListDocumentsResponse", + "ListFilesRequest", + "ListFilesResponse", + "ListModelsRequest", + "ListModelsResponse", + "ListPermissionsRequest", + "ListPermissionsResponse", + "ListTunedModelsRequest", + "ListTunedModelsResponse", + "LogprobsResult", + "Message", + "MessagePrompt", + "MetadataFilter", + "Model", + "ModelServiceClient", + "Part", + "Permission", + "PermissionServiceClient", + "PrebuiltVoiceConfig", + "PredictRequest", + "PredictResponse", + "PredictionServiceClient", + "QueryCorpusRequest", + "QueryCorpusResponse", + "QueryDocumentRequest", + "QueryDocumentResponse", + "RelevantChunk", + "RetrievalMetadata", + "RetrieverServiceClient", + "SafetyFeedback", + "SafetyRating", + "SafetySetting", + "Schema", + "SearchEntryPoint", + "Segment", + "SemanticRetrieverConfig", + "SpeechConfig", + "StringList", + "TaskType", + "TextCompletion", + "TextPrompt", + "TextServiceClient", + "Tool", + "ToolConfig", + "TransferOwnershipRequest", + "TransferOwnershipResponse", + "TunedModel", + "TunedModelSource", + "TuningContent", + "TuningExample", + "TuningExamples", + "TuningMultiturnExample", + "TuningPart", + "TuningSnapshot", + "TuningTask", + "Type", + "UpdateCachedContentRequest", + "UpdateChunkRequest", + "UpdateCorpusRequest", + "UpdateDocumentRequest", + "UpdatePermissionRequest", + "UpdateTunedModelRequest", + "VideoMetadata", + "VoiceConfig", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_metadata.json b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_metadata.json new file mode 100644 index 000000000000..db219e407c36 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_metadata.json @@ -0,0 +1,1020 @@ + { + "comment": "This file maps proto services/RPCs to the corresponding library clients/methods", + "language": "python", + "libraryPackage": "google.ai.generativelanguage_v1alpha", + "protoPackage": "google.ai.generativelanguage.v1alpha", + "schema": "1.0", + "services": { + "CacheService": { + "clients": { + "grpc": { + "libraryClient": "CacheServiceClient", + "rpcs": { + "CreateCachedContent": { + "methods": [ + "create_cached_content" + ] + }, + "DeleteCachedContent": { + "methods": [ + "delete_cached_content" + ] + }, + "GetCachedContent": { + "methods": [ + "get_cached_content" + ] + }, + "ListCachedContents": { + "methods": [ + "list_cached_contents" + ] + }, + "UpdateCachedContent": { + "methods": [ + "update_cached_content" + ] + } + } + }, + "grpc-async": { + "libraryClient": "CacheServiceAsyncClient", + "rpcs": { + "CreateCachedContent": { + "methods": [ + "create_cached_content" + ] + }, + "DeleteCachedContent": { + "methods": [ + "delete_cached_content" + ] + }, + "GetCachedContent": { + "methods": [ + "get_cached_content" + ] + }, + "ListCachedContents": { + "methods": [ + "list_cached_contents" + ] + }, + "UpdateCachedContent": { + "methods": [ + "update_cached_content" + ] + } + } + }, + "rest": { + "libraryClient": "CacheServiceClient", + "rpcs": { + "CreateCachedContent": { + "methods": [ + "create_cached_content" + ] + }, + "DeleteCachedContent": { + "methods": [ + "delete_cached_content" + ] + }, + "GetCachedContent": { + "methods": [ + "get_cached_content" + ] + }, + "ListCachedContents": { + "methods": [ + "list_cached_contents" + ] + }, + "UpdateCachedContent": { + "methods": [ + "update_cached_content" + ] + } + } + } + } + }, + "DiscussService": { + "clients": { + "grpc": { + "libraryClient": "DiscussServiceClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + }, + "grpc-async": { + "libraryClient": "DiscussServiceAsyncClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + }, + "rest": { + "libraryClient": "DiscussServiceClient", + "rpcs": { + "CountMessageTokens": { + "methods": [ + "count_message_tokens" + ] + }, + "GenerateMessage": { + "methods": [ + "generate_message" + ] + } + } + } + } + }, + "FileService": { + "clients": { + "grpc": { + "libraryClient": "FileServiceClient", + "rpcs": { + "CreateFile": { + "methods": [ + "create_file" + ] + }, + "DeleteFile": { + "methods": [ + "delete_file" + ] + }, + "GetFile": { + "methods": [ + "get_file" + ] + }, + "ListFiles": { + "methods": [ + "list_files" + ] + } + } + }, + "grpc-async": { + "libraryClient": "FileServiceAsyncClient", + "rpcs": { + "CreateFile": { + "methods": [ + "create_file" + ] + }, + "DeleteFile": { + "methods": [ + "delete_file" + ] + }, + "GetFile": { + "methods": [ + "get_file" + ] + }, + "ListFiles": { + "methods": [ + "list_files" + ] + } + } + }, + "rest": { + "libraryClient": "FileServiceClient", + "rpcs": { + "CreateFile": { + "methods": [ + "create_file" + ] + }, + "DeleteFile": { + "methods": [ + "delete_file" + ] + }, + "GetFile": { + "methods": [ + "get_file" + ] + }, + "ListFiles": { + "methods": [ + "list_files" + ] + } + } + } + } + }, + "GenerativeService": { + "clients": { + "grpc": { + "libraryClient": "GenerativeServiceClient", + "rpcs": { + "BatchEmbedContents": { + "methods": [ + "batch_embed_contents" + ] + }, + "BidiGenerateContent": { + "methods": [ + "bidi_generate_content" + ] + }, + "CountTokens": { + "methods": [ + "count_tokens" + ] + }, + "EmbedContent": { + "methods": [ + "embed_content" + ] + }, + "GenerateAnswer": { + "methods": [ + "generate_answer" + ] + }, + "GenerateContent": { + "methods": [ + "generate_content" + ] + }, + "StreamGenerateContent": { + "methods": [ + "stream_generate_content" + ] + } + } + }, + "grpc-async": { + "libraryClient": "GenerativeServiceAsyncClient", + "rpcs": { + "BatchEmbedContents": { + "methods": [ + "batch_embed_contents" + ] + }, + "BidiGenerateContent": { + "methods": [ + "bidi_generate_content" + ] + }, + "CountTokens": { + "methods": [ + "count_tokens" + ] + }, + "EmbedContent": { + "methods": [ + "embed_content" + ] + }, + "GenerateAnswer": { + "methods": [ + "generate_answer" + ] + }, + "GenerateContent": { + "methods": [ + "generate_content" + ] + }, + "StreamGenerateContent": { + "methods": [ + "stream_generate_content" + ] + } + } + }, + "rest": { + "libraryClient": "GenerativeServiceClient", + "rpcs": { + "BatchEmbedContents": { + "methods": [ + "batch_embed_contents" + ] + }, + "BidiGenerateContent": { + "methods": [ + "bidi_generate_content" + ] + }, + "CountTokens": { + "methods": [ + "count_tokens" + ] + }, + "EmbedContent": { + "methods": [ + "embed_content" + ] + }, + "GenerateAnswer": { + "methods": [ + "generate_answer" + ] + }, + "GenerateContent": { + "methods": [ + "generate_content" + ] + }, + "StreamGenerateContent": { + "methods": [ + "stream_generate_content" + ] + } + } + } + } + }, + "ModelService": { + "clients": { + "grpc": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "CreateTunedModel": { + "methods": [ + "create_tuned_model" + ] + }, + "DeleteTunedModel": { + "methods": [ + "delete_tuned_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetTunedModel": { + "methods": [ + "get_tuned_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "ListTunedModels": { + "methods": [ + "list_tuned_models" + ] + }, + "UpdateTunedModel": { + "methods": [ + "update_tuned_model" + ] + } + } + }, + "grpc-async": { + "libraryClient": "ModelServiceAsyncClient", + "rpcs": { + "CreateTunedModel": { + "methods": [ + "create_tuned_model" + ] + }, + "DeleteTunedModel": { + "methods": [ + "delete_tuned_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetTunedModel": { + "methods": [ + "get_tuned_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "ListTunedModels": { + "methods": [ + "list_tuned_models" + ] + }, + "UpdateTunedModel": { + "methods": [ + "update_tuned_model" + ] + } + } + }, + "rest": { + "libraryClient": "ModelServiceClient", + "rpcs": { + "CreateTunedModel": { + "methods": [ + "create_tuned_model" + ] + }, + "DeleteTunedModel": { + "methods": [ + "delete_tuned_model" + ] + }, + "GetModel": { + "methods": [ + "get_model" + ] + }, + "GetTunedModel": { + "methods": [ + "get_tuned_model" + ] + }, + "ListModels": { + "methods": [ + "list_models" + ] + }, + "ListTunedModels": { + "methods": [ + "list_tuned_models" + ] + }, + "UpdateTunedModel": { + "methods": [ + "update_tuned_model" + ] + } + } + } + } + }, + "PermissionService": { + "clients": { + "grpc": { + "libraryClient": "PermissionServiceClient", + "rpcs": { + "CreatePermission": { + "methods": [ + "create_permission" + ] + }, + "DeletePermission": { + "methods": [ + "delete_permission" + ] + }, + "GetPermission": { + "methods": [ + "get_permission" + ] + }, + "ListPermissions": { + "methods": [ + "list_permissions" + ] + }, + "TransferOwnership": { + "methods": [ + "transfer_ownership" + ] + }, + "UpdatePermission": { + "methods": [ + "update_permission" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PermissionServiceAsyncClient", + "rpcs": { + "CreatePermission": { + "methods": [ + "create_permission" + ] + }, + "DeletePermission": { + "methods": [ + "delete_permission" + ] + }, + "GetPermission": { + "methods": [ + "get_permission" + ] + }, + "ListPermissions": { + "methods": [ + "list_permissions" + ] + }, + "TransferOwnership": { + "methods": [ + "transfer_ownership" + ] + }, + "UpdatePermission": { + "methods": [ + "update_permission" + ] + } + } + }, + "rest": { + "libraryClient": "PermissionServiceClient", + "rpcs": { + "CreatePermission": { + "methods": [ + "create_permission" + ] + }, + "DeletePermission": { + "methods": [ + "delete_permission" + ] + }, + "GetPermission": { + "methods": [ + "get_permission" + ] + }, + "ListPermissions": { + "methods": [ + "list_permissions" + ] + }, + "TransferOwnership": { + "methods": [ + "transfer_ownership" + ] + }, + "UpdatePermission": { + "methods": [ + "update_permission" + ] + } + } + } + } + }, + "PredictionService": { + "clients": { + "grpc": { + "libraryClient": "PredictionServiceClient", + "rpcs": { + "Predict": { + "methods": [ + "predict" + ] + } + } + }, + "grpc-async": { + "libraryClient": "PredictionServiceAsyncClient", + "rpcs": { + "Predict": { + "methods": [ + "predict" + ] + } + } + }, + "rest": { + "libraryClient": "PredictionServiceClient", + "rpcs": { + "Predict": { + "methods": [ + "predict" + ] + } + } + } + } + }, + "RetrieverService": { + "clients": { + "grpc": { + "libraryClient": "RetrieverServiceClient", + "rpcs": { + "BatchCreateChunks": { + "methods": [ + "batch_create_chunks" + ] + }, + "BatchDeleteChunks": { + "methods": [ + "batch_delete_chunks" + ] + }, + "BatchUpdateChunks": { + "methods": [ + "batch_update_chunks" + ] + }, + "CreateChunk": { + "methods": [ + "create_chunk" + ] + }, + "CreateCorpus": { + "methods": [ + "create_corpus" + ] + }, + "CreateDocument": { + "methods": [ + "create_document" + ] + }, + "DeleteChunk": { + "methods": [ + "delete_chunk" + ] + }, + "DeleteCorpus": { + "methods": [ + "delete_corpus" + ] + }, + "DeleteDocument": { + "methods": [ + "delete_document" + ] + }, + "GetChunk": { + "methods": [ + "get_chunk" + ] + }, + "GetCorpus": { + "methods": [ + "get_corpus" + ] + }, + "GetDocument": { + "methods": [ + "get_document" + ] + }, + "ListChunks": { + "methods": [ + "list_chunks" + ] + }, + "ListCorpora": { + "methods": [ + "list_corpora" + ] + }, + "ListDocuments": { + "methods": [ + "list_documents" + ] + }, + "QueryCorpus": { + "methods": [ + "query_corpus" + ] + }, + "QueryDocument": { + "methods": [ + "query_document" + ] + }, + "UpdateChunk": { + "methods": [ + "update_chunk" + ] + }, + "UpdateCorpus": { + "methods": [ + "update_corpus" + ] + }, + "UpdateDocument": { + "methods": [ + "update_document" + ] + } + } + }, + "grpc-async": { + "libraryClient": "RetrieverServiceAsyncClient", + "rpcs": { + "BatchCreateChunks": { + "methods": [ + "batch_create_chunks" + ] + }, + "BatchDeleteChunks": { + "methods": [ + "batch_delete_chunks" + ] + }, + "BatchUpdateChunks": { + "methods": [ + "batch_update_chunks" + ] + }, + "CreateChunk": { + "methods": [ + "create_chunk" + ] + }, + "CreateCorpus": { + "methods": [ + "create_corpus" + ] + }, + "CreateDocument": { + "methods": [ + "create_document" + ] + }, + "DeleteChunk": { + "methods": [ + "delete_chunk" + ] + }, + "DeleteCorpus": { + "methods": [ + "delete_corpus" + ] + }, + "DeleteDocument": { + "methods": [ + "delete_document" + ] + }, + "GetChunk": { + "methods": [ + "get_chunk" + ] + }, + "GetCorpus": { + "methods": [ + "get_corpus" + ] + }, + "GetDocument": { + "methods": [ + "get_document" + ] + }, + "ListChunks": { + "methods": [ + "list_chunks" + ] + }, + "ListCorpora": { + "methods": [ + "list_corpora" + ] + }, + "ListDocuments": { + "methods": [ + "list_documents" + ] + }, + "QueryCorpus": { + "methods": [ + "query_corpus" + ] + }, + "QueryDocument": { + "methods": [ + "query_document" + ] + }, + "UpdateChunk": { + "methods": [ + "update_chunk" + ] + }, + "UpdateCorpus": { + "methods": [ + "update_corpus" + ] + }, + "UpdateDocument": { + "methods": [ + "update_document" + ] + } + } + }, + "rest": { + "libraryClient": "RetrieverServiceClient", + "rpcs": { + "BatchCreateChunks": { + "methods": [ + "batch_create_chunks" + ] + }, + "BatchDeleteChunks": { + "methods": [ + "batch_delete_chunks" + ] + }, + "BatchUpdateChunks": { + "methods": [ + "batch_update_chunks" + ] + }, + "CreateChunk": { + "methods": [ + "create_chunk" + ] + }, + "CreateCorpus": { + "methods": [ + "create_corpus" + ] + }, + "CreateDocument": { + "methods": [ + "create_document" + ] + }, + "DeleteChunk": { + "methods": [ + "delete_chunk" + ] + }, + "DeleteCorpus": { + "methods": [ + "delete_corpus" + ] + }, + "DeleteDocument": { + "methods": [ + "delete_document" + ] + }, + "GetChunk": { + "methods": [ + "get_chunk" + ] + }, + "GetCorpus": { + "methods": [ + "get_corpus" + ] + }, + "GetDocument": { + "methods": [ + "get_document" + ] + }, + "ListChunks": { + "methods": [ + "list_chunks" + ] + }, + "ListCorpora": { + "methods": [ + "list_corpora" + ] + }, + "ListDocuments": { + "methods": [ + "list_documents" + ] + }, + "QueryCorpus": { + "methods": [ + "query_corpus" + ] + }, + "QueryDocument": { + "methods": [ + "query_document" + ] + }, + "UpdateChunk": { + "methods": [ + "update_chunk" + ] + }, + "UpdateCorpus": { + "methods": [ + "update_corpus" + ] + }, + "UpdateDocument": { + "methods": [ + "update_document" + ] + } + } + } + } + }, + "TextService": { + "clients": { + "grpc": { + "libraryClient": "TextServiceClient", + "rpcs": { + "BatchEmbedText": { + "methods": [ + "batch_embed_text" + ] + }, + "CountTextTokens": { + "methods": [ + "count_text_tokens" + ] + }, + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + }, + "grpc-async": { + "libraryClient": "TextServiceAsyncClient", + "rpcs": { + "BatchEmbedText": { + "methods": [ + "batch_embed_text" + ] + }, + "CountTextTokens": { + "methods": [ + "count_text_tokens" + ] + }, + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + }, + "rest": { + "libraryClient": "TextServiceClient", + "rpcs": { + "BatchEmbedText": { + "methods": [ + "batch_embed_text" + ] + }, + "CountTextTokens": { + "methods": [ + "count_text_tokens" + ] + }, + "EmbedText": { + "methods": [ + "embed_text" + ] + }, + "GenerateText": { + "methods": [ + "generate_text" + ] + } + } + } + } + } + } +} diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_version.py new file mode 100644 index 000000000000..558c8aab67c5 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/gapic_version.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +__version__ = "0.0.0" # {x-release-please-version} diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/py.typed b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/py.typed new file mode 100644 index 000000000000..38773eee6363 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561. +# The google-ai-generativelanguage package uses inline types. diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/__init__.py new file mode 100644 index 000000000000..8f6cf068242c --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/__init__.py new file mode 100644 index 000000000000..2f8bc2e1ba03 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import CacheServiceAsyncClient +from .client import CacheServiceClient + +__all__ = ( + "CacheServiceClient", + "CacheServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/async_client.py new file mode 100644 index 000000000000..ff56bab13223 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/async_client.py @@ -0,0 +1,950 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.cache_service import pagers +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content +from google.ai.generativelanguage_v1alpha.types import content + +from .client import CacheServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, CacheServiceTransport +from .transports.grpc_asyncio import CacheServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class CacheServiceAsyncClient: + """API for managing cache of content (CachedContent resources) + that can be used in GenerativeService requests. This way + generate content requests can benefit from preprocessing work + being done earlier, possibly lowering their computational cost. + It is intended to be used with large contexts. + """ + + _client: CacheServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = CacheServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = CacheServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = CacheServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = CacheServiceClient._DEFAULT_UNIVERSE + + cached_content_path = staticmethod(CacheServiceClient.cached_content_path) + parse_cached_content_path = staticmethod( + CacheServiceClient.parse_cached_content_path + ) + model_path = staticmethod(CacheServiceClient.model_path) + parse_model_path = staticmethod(CacheServiceClient.parse_model_path) + common_billing_account_path = staticmethod( + CacheServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + CacheServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(CacheServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(CacheServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(CacheServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + CacheServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(CacheServiceClient.common_project_path) + parse_common_project_path = staticmethod( + CacheServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(CacheServiceClient.common_location_path) + parse_common_location_path = staticmethod( + CacheServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + CacheServiceAsyncClient: The constructed client. + """ + return CacheServiceClient.from_service_account_info.__func__(CacheServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + CacheServiceAsyncClient: The constructed client. + """ + return CacheServiceClient.from_service_account_file.__func__(CacheServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return CacheServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> CacheServiceTransport: + """Returns the transport used by the client instance. + + Returns: + CacheServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = CacheServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, CacheServiceTransport, Callable[..., CacheServiceTransport]] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the cache service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,CacheServiceTransport,Callable[..., CacheServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the CacheServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = CacheServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "credentialsType": None, + }, + ) + + async def list_cached_contents( + self, + request: Optional[Union[cache_service.ListCachedContentsRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListCachedContentsAsyncPager: + r"""Lists CachedContents. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_cached_contents(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCachedContentsRequest( + ) + + # Make the request + page_result = client.list_cached_contents(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListCachedContentsRequest, dict]]): + The request object. Request to list CachedContents. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.cache_service.pagers.ListCachedContentsAsyncPager: + Response with CachedContents list. + + Iterating over this object will yield + results and resolve additional pages + automatically. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.ListCachedContentsRequest): + request = cache_service.ListCachedContentsRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_cached_contents + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListCachedContentsAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_cached_content( + self, + request: Optional[Union[cache_service.CreateCachedContentRequest, dict]] = None, + *, + cached_content: Optional[gag_cached_content.CachedContent] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_cached_content.CachedContent: + r"""Creates CachedContent resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCachedContentRequest( + ) + + # Make the request + response = await client.create_cached_content(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreateCachedContentRequest, dict]]): + The request object. Request to create CachedContent. + cached_content (:class:`google.ai.generativelanguage_v1alpha.types.CachedContent`): + Required. The cached content to + create. + + This corresponds to the ``cached_content`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([cached_content]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.CreateCachedContentRequest): + request = cache_service.CreateCachedContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if cached_content is not None: + request.cached_content = cached_content + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_cached_content + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_cached_content( + self, + request: Optional[Union[cache_service.GetCachedContentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> cached_content.CachedContent: + r"""Reads CachedContent resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCachedContentRequest( + name="name_value", + ) + + # Make the request + response = await client.get_cached_content(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetCachedContentRequest, dict]]): + The request object. Request to read CachedContent. + name (:class:`str`): + Required. The resource name referring to the content + cache entry. Format: ``cachedContents/{id}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.GetCachedContentRequest): + request = cache_service.GetCachedContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_cached_content + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_cached_content( + self, + request: Optional[Union[cache_service.UpdateCachedContentRequest, dict]] = None, + *, + cached_content: Optional[gag_cached_content.CachedContent] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_cached_content.CachedContent: + r"""Updates CachedContent resource (only expiration is + updatable). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_update_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCachedContentRequest( + ) + + # Make the request + response = await client.update_cached_content(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.UpdateCachedContentRequest, dict]]): + The request object. Request to update CachedContent. + cached_content (:class:`google.ai.generativelanguage_v1alpha.types.CachedContent`): + Required. The content cache entry to + update + + This corresponds to the ``cached_content`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + The list of fields to update. + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([cached_content, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.UpdateCachedContentRequest): + request = cache_service.UpdateCachedContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if cached_content is not None: + request.cached_content = cached_content + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_cached_content + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("cached_content.name", request.cached_content.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_cached_content( + self, + request: Optional[Union[cache_service.DeleteCachedContentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes CachedContent resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCachedContentRequest( + name="name_value", + ) + + # Make the request + await client.delete_cached_content(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeleteCachedContentRequest, dict]]): + The request object. Request to delete CachedContent. + name (:class:`str`): + Required. The resource name referring to the content + cache entry Format: ``cachedContents/{id}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.DeleteCachedContentRequest): + request = cache_service.DeleteCachedContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_cached_content + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "CacheServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("CacheServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/client.py new file mode 100644 index 000000000000..ef87e6cbf854 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/client.py @@ -0,0 +1,1340 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.cache_service import pagers +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content +from google.ai.generativelanguage_v1alpha.types import content + +from .transports.base import DEFAULT_CLIENT_INFO, CacheServiceTransport +from .transports.grpc import CacheServiceGrpcTransport +from .transports.grpc_asyncio import CacheServiceGrpcAsyncIOTransport +from .transports.rest import CacheServiceRestTransport + + +class CacheServiceClientMeta(type): + """Metaclass for the CacheService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[CacheServiceTransport]] + _transport_registry["grpc"] = CacheServiceGrpcTransport + _transport_registry["grpc_asyncio"] = CacheServiceGrpcAsyncIOTransport + _transport_registry["rest"] = CacheServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[CacheServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class CacheServiceClient(metaclass=CacheServiceClientMeta): + """API for managing cache of content (CachedContent resources) + that can be used in GenerativeService requests. This way + generate content requests can benefit from preprocessing work + being done earlier, possibly lowering their computational cost. + It is intended to be used with large contexts. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + CacheServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + CacheServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> CacheServiceTransport: + """Returns the transport used by the client instance. + + Returns: + CacheServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def cached_content_path( + id: str, + ) -> str: + """Returns a fully-qualified cached_content string.""" + return "cachedContents/{id}".format( + id=id, + ) + + @staticmethod + def parse_cached_content_path(path: str) -> Dict[str, str]: + """Parses a cached_content path into its component segments.""" + m = re.match(r"^cachedContents/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_path( + model: str, + ) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format( + model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = CacheServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = CacheServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = CacheServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = CacheServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, CacheServiceTransport, Callable[..., CacheServiceTransport]] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the cache service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,CacheServiceTransport,Callable[..., CacheServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the CacheServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = CacheServiceClient._read_environment_variables() + self._client_cert_source = CacheServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = CacheServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, CacheServiceTransport) + if transport_provided: + # transport is a CacheServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(CacheServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = self._api_endpoint or CacheServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[CacheServiceTransport], Callable[..., CacheServiceTransport] + ] = ( + CacheServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., CacheServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.CacheServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "credentialsType": None, + }, + ) + + def list_cached_contents( + self, + request: Optional[Union[cache_service.ListCachedContentsRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListCachedContentsPager: + r"""Lists CachedContents. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_cached_contents(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCachedContentsRequest( + ) + + # Make the request + page_result = client.list_cached_contents(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListCachedContentsRequest, dict]): + The request object. Request to list CachedContents. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.cache_service.pagers.ListCachedContentsPager: + Response with CachedContents list. + + Iterating over this object will yield + results and resolve additional pages + automatically. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.ListCachedContentsRequest): + request = cache_service.ListCachedContentsRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_cached_contents] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListCachedContentsPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_cached_content( + self, + request: Optional[Union[cache_service.CreateCachedContentRequest, dict]] = None, + *, + cached_content: Optional[gag_cached_content.CachedContent] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_cached_content.CachedContent: + r"""Creates CachedContent resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCachedContentRequest( + ) + + # Make the request + response = client.create_cached_content(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreateCachedContentRequest, dict]): + The request object. Request to create CachedContent. + cached_content (google.ai.generativelanguage_v1alpha.types.CachedContent): + Required. The cached content to + create. + + This corresponds to the ``cached_content`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([cached_content]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.CreateCachedContentRequest): + request = cache_service.CreateCachedContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if cached_content is not None: + request.cached_content = cached_content + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_cached_content] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_cached_content( + self, + request: Optional[Union[cache_service.GetCachedContentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> cached_content.CachedContent: + r"""Reads CachedContent resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCachedContentRequest( + name="name_value", + ) + + # Make the request + response = client.get_cached_content(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetCachedContentRequest, dict]): + The request object. Request to read CachedContent. + name (str): + Required. The resource name referring to the content + cache entry. Format: ``cachedContents/{id}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.GetCachedContentRequest): + request = cache_service.GetCachedContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_cached_content] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_cached_content( + self, + request: Optional[Union[cache_service.UpdateCachedContentRequest, dict]] = None, + *, + cached_content: Optional[gag_cached_content.CachedContent] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_cached_content.CachedContent: + r"""Updates CachedContent resource (only expiration is + updatable). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_update_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCachedContentRequest( + ) + + # Make the request + response = client.update_cached_content(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.UpdateCachedContentRequest, dict]): + The request object. Request to update CachedContent. + cached_content (google.ai.generativelanguage_v1alpha.types.CachedContent): + Required. The content cache entry to + update + + This corresponds to the ``cached_content`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + The list of fields to update. + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([cached_content, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.UpdateCachedContentRequest): + request = cache_service.UpdateCachedContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if cached_content is not None: + request.cached_content = cached_content + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_cached_content] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("cached_content.name", request.cached_content.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_cached_content( + self, + request: Optional[Union[cache_service.DeleteCachedContentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes CachedContent resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCachedContentRequest( + name="name_value", + ) + + # Make the request + client.delete_cached_content(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeleteCachedContentRequest, dict]): + The request object. Request to delete CachedContent. + name (str): + Required. The resource name referring to the content + cache entry Format: ``cachedContents/{id}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, cache_service.DeleteCachedContentRequest): + request = cache_service.DeleteCachedContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_cached_content] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def __enter__(self) -> "CacheServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("CacheServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/pagers.py new file mode 100644 index 000000000000..0a43aadf2d53 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/pagers.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) + +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + +from google.ai.generativelanguage_v1alpha.types import cache_service, cached_content + + +class ListCachedContentsPager: + """A pager for iterating through ``list_cached_contents`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListCachedContentsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``cached_contents`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListCachedContents`` requests and continue to iterate + through the ``cached_contents`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListCachedContentsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., cache_service.ListCachedContentsResponse], + request: cache_service.ListCachedContentsRequest, + response: cache_service.ListCachedContentsResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListCachedContentsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListCachedContentsResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = cache_service.ListCachedContentsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[cache_service.ListCachedContentsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[cached_content.CachedContent]: + for page in self.pages: + yield from page.cached_contents + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListCachedContentsAsyncPager: + """A pager for iterating through ``list_cached_contents`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListCachedContentsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``cached_contents`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListCachedContents`` requests and continue to iterate + through the ``cached_contents`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListCachedContentsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[cache_service.ListCachedContentsResponse]], + request: cache_service.ListCachedContentsRequest, + response: cache_service.ListCachedContentsResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListCachedContentsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListCachedContentsResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = cache_service.ListCachedContentsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[cache_service.ListCachedContentsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[cached_content.CachedContent]: + async def async_generator(): + async for page in self.pages: + for response in page.cached_contents: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/README.rst new file mode 100644 index 000000000000..8647ac1a9f9b --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`CacheServiceTransport` is the ABC for all transports. +- public child `CacheServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `CacheServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseCacheServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `CacheServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/__init__.py new file mode 100644 index 000000000000..cef091cd23ab --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import CacheServiceTransport +from .grpc import CacheServiceGrpcTransport +from .grpc_asyncio import CacheServiceGrpcAsyncIOTransport +from .rest import CacheServiceRestInterceptor, CacheServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[CacheServiceTransport]] +_transport_registry["grpc"] = CacheServiceGrpcTransport +_transport_registry["grpc_asyncio"] = CacheServiceGrpcAsyncIOTransport +_transport_registry["rest"] = CacheServiceRestTransport + +__all__ = ( + "CacheServiceTransport", + "CacheServiceGrpcTransport", + "CacheServiceGrpcAsyncIOTransport", + "CacheServiceRestTransport", + "CacheServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/base.py new file mode 100644 index 000000000000..ee82d216203c --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/base.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class CacheServiceTransport(abc.ABC): + """Abstract transport class for CacheService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.list_cached_contents: gapic_v1.method.wrap_method( + self.list_cached_contents, + default_timeout=None, + client_info=client_info, + ), + self.create_cached_content: gapic_v1.method.wrap_method( + self.create_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.get_cached_content: gapic_v1.method.wrap_method( + self.get_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.update_cached_content: gapic_v1.method.wrap_method( + self.update_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.delete_cached_content: gapic_v1.method.wrap_method( + self.delete_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def list_cached_contents( + self, + ) -> Callable[ + [cache_service.ListCachedContentsRequest], + Union[ + cache_service.ListCachedContentsResponse, + Awaitable[cache_service.ListCachedContentsResponse], + ], + ]: + raise NotImplementedError() + + @property + def create_cached_content( + self, + ) -> Callable[ + [cache_service.CreateCachedContentRequest], + Union[ + gag_cached_content.CachedContent, + Awaitable[gag_cached_content.CachedContent], + ], + ]: + raise NotImplementedError() + + @property + def get_cached_content( + self, + ) -> Callable[ + [cache_service.GetCachedContentRequest], + Union[cached_content.CachedContent, Awaitable[cached_content.CachedContent]], + ]: + raise NotImplementedError() + + @property + def update_cached_content( + self, + ) -> Callable[ + [cache_service.UpdateCachedContentRequest], + Union[ + gag_cached_content.CachedContent, + Awaitable[gag_cached_content.CachedContent], + ], + ]: + raise NotImplementedError() + + @property + def delete_cached_content( + self, + ) -> Callable[ + [cache_service.DeleteCachedContentRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("CacheServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc.py new file mode 100644 index 000000000000..c337588e39d7 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc.py @@ -0,0 +1,517 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content + +from .base import DEFAULT_CLIENT_INFO, CacheServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class CacheServiceGrpcTransport(CacheServiceTransport): + """gRPC backend transport for CacheService. + + API for managing cache of content (CachedContent resources) + that can be used in GenerativeService requests. This way + generate content requests can benefit from preprocessing work + being done earlier, possibly lowering their computational cost. + It is intended to be used with large contexts. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def list_cached_contents( + self, + ) -> Callable[ + [cache_service.ListCachedContentsRequest], + cache_service.ListCachedContentsResponse, + ]: + r"""Return a callable for the list cached contents method over gRPC. + + Lists CachedContents. + + Returns: + Callable[[~.ListCachedContentsRequest], + ~.ListCachedContentsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_cached_contents" not in self._stubs: + self._stubs["list_cached_contents"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/ListCachedContents", + request_serializer=cache_service.ListCachedContentsRequest.serialize, + response_deserializer=cache_service.ListCachedContentsResponse.deserialize, + ) + return self._stubs["list_cached_contents"] + + @property + def create_cached_content( + self, + ) -> Callable[ + [cache_service.CreateCachedContentRequest], gag_cached_content.CachedContent + ]: + r"""Return a callable for the create cached content method over gRPC. + + Creates CachedContent resource. + + Returns: + Callable[[~.CreateCachedContentRequest], + ~.CachedContent]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_cached_content" not in self._stubs: + self._stubs["create_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/CreateCachedContent", + request_serializer=cache_service.CreateCachedContentRequest.serialize, + response_deserializer=gag_cached_content.CachedContent.deserialize, + ) + return self._stubs["create_cached_content"] + + @property + def get_cached_content( + self, + ) -> Callable[ + [cache_service.GetCachedContentRequest], cached_content.CachedContent + ]: + r"""Return a callable for the get cached content method over gRPC. + + Reads CachedContent resource. + + Returns: + Callable[[~.GetCachedContentRequest], + ~.CachedContent]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_cached_content" not in self._stubs: + self._stubs["get_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/GetCachedContent", + request_serializer=cache_service.GetCachedContentRequest.serialize, + response_deserializer=cached_content.CachedContent.deserialize, + ) + return self._stubs["get_cached_content"] + + @property + def update_cached_content( + self, + ) -> Callable[ + [cache_service.UpdateCachedContentRequest], gag_cached_content.CachedContent + ]: + r"""Return a callable for the update cached content method over gRPC. + + Updates CachedContent resource (only expiration is + updatable). + + Returns: + Callable[[~.UpdateCachedContentRequest], + ~.CachedContent]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_cached_content" not in self._stubs: + self._stubs["update_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/UpdateCachedContent", + request_serializer=cache_service.UpdateCachedContentRequest.serialize, + response_deserializer=gag_cached_content.CachedContent.deserialize, + ) + return self._stubs["update_cached_content"] + + @property + def delete_cached_content( + self, + ) -> Callable[[cache_service.DeleteCachedContentRequest], empty_pb2.Empty]: + r"""Return a callable for the delete cached content method over gRPC. + + Deletes CachedContent resource. + + Returns: + Callable[[~.DeleteCachedContentRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_cached_content" not in self._stubs: + self._stubs["delete_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/DeleteCachedContent", + request_serializer=cache_service.DeleteCachedContentRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_cached_content"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("CacheServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..764e2a1cb29f --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/grpc_asyncio.py @@ -0,0 +1,573 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content + +from .base import DEFAULT_CLIENT_INFO, CacheServiceTransport +from .grpc import CacheServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class CacheServiceGrpcAsyncIOTransport(CacheServiceTransport): + """gRPC AsyncIO backend transport for CacheService. + + API for managing cache of content (CachedContent resources) + that can be used in GenerativeService requests. This way + generate content requests can benefit from preprocessing work + being done earlier, possibly lowering their computational cost. + It is intended to be used with large contexts. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def list_cached_contents( + self, + ) -> Callable[ + [cache_service.ListCachedContentsRequest], + Awaitable[cache_service.ListCachedContentsResponse], + ]: + r"""Return a callable for the list cached contents method over gRPC. + + Lists CachedContents. + + Returns: + Callable[[~.ListCachedContentsRequest], + Awaitable[~.ListCachedContentsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_cached_contents" not in self._stubs: + self._stubs["list_cached_contents"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/ListCachedContents", + request_serializer=cache_service.ListCachedContentsRequest.serialize, + response_deserializer=cache_service.ListCachedContentsResponse.deserialize, + ) + return self._stubs["list_cached_contents"] + + @property + def create_cached_content( + self, + ) -> Callable[ + [cache_service.CreateCachedContentRequest], + Awaitable[gag_cached_content.CachedContent], + ]: + r"""Return a callable for the create cached content method over gRPC. + + Creates CachedContent resource. + + Returns: + Callable[[~.CreateCachedContentRequest], + Awaitable[~.CachedContent]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_cached_content" not in self._stubs: + self._stubs["create_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/CreateCachedContent", + request_serializer=cache_service.CreateCachedContentRequest.serialize, + response_deserializer=gag_cached_content.CachedContent.deserialize, + ) + return self._stubs["create_cached_content"] + + @property + def get_cached_content( + self, + ) -> Callable[ + [cache_service.GetCachedContentRequest], Awaitable[cached_content.CachedContent] + ]: + r"""Return a callable for the get cached content method over gRPC. + + Reads CachedContent resource. + + Returns: + Callable[[~.GetCachedContentRequest], + Awaitable[~.CachedContent]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_cached_content" not in self._stubs: + self._stubs["get_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/GetCachedContent", + request_serializer=cache_service.GetCachedContentRequest.serialize, + response_deserializer=cached_content.CachedContent.deserialize, + ) + return self._stubs["get_cached_content"] + + @property + def update_cached_content( + self, + ) -> Callable[ + [cache_service.UpdateCachedContentRequest], + Awaitable[gag_cached_content.CachedContent], + ]: + r"""Return a callable for the update cached content method over gRPC. + + Updates CachedContent resource (only expiration is + updatable). + + Returns: + Callable[[~.UpdateCachedContentRequest], + Awaitable[~.CachedContent]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_cached_content" not in self._stubs: + self._stubs["update_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/UpdateCachedContent", + request_serializer=cache_service.UpdateCachedContentRequest.serialize, + response_deserializer=gag_cached_content.CachedContent.deserialize, + ) + return self._stubs["update_cached_content"] + + @property + def delete_cached_content( + self, + ) -> Callable[ + [cache_service.DeleteCachedContentRequest], Awaitable[empty_pb2.Empty] + ]: + r"""Return a callable for the delete cached content method over gRPC. + + Deletes CachedContent resource. + + Returns: + Callable[[~.DeleteCachedContentRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_cached_content" not in self._stubs: + self._stubs["delete_cached_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.CacheService/DeleteCachedContent", + request_serializer=cache_service.DeleteCachedContentRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_cached_content"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.list_cached_contents: self._wrap_method( + self.list_cached_contents, + default_timeout=None, + client_info=client_info, + ), + self.create_cached_content: self._wrap_method( + self.create_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.get_cached_content: self._wrap_method( + self.get_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.update_cached_content: self._wrap_method( + self.update_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.delete_cached_content: self._wrap_method( + self.delete_cached_content, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("CacheServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest.py new file mode 100644 index 000000000000..e89f49ec8994 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest.py @@ -0,0 +1,1428 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseCacheServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class CacheServiceRestInterceptor: + """Interceptor for CacheService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the CacheServiceRestTransport. + + .. code-block:: python + class MyCustomCacheServiceInterceptor(CacheServiceRestInterceptor): + def pre_create_cached_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_cached_content(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_cached_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_cached_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_cached_content(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_cached_contents(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_cached_contents(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_cached_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_cached_content(self, response): + logging.log(f"Received response: {response}") + return response + + transport = CacheServiceRestTransport(interceptor=MyCustomCacheServiceInterceptor()) + client = CacheServiceClient(transport=transport) + + + """ + + def pre_create_cached_content( + self, + request: cache_service.CreateCachedContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + cache_service.CreateCachedContentRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for create_cached_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def post_create_cached_content( + self, response: gag_cached_content.CachedContent + ) -> gag_cached_content.CachedContent: + """Post-rpc interceptor for create_cached_content + + Override in a subclass to manipulate the response + after it is returned by the CacheService server but before + it is returned to user code. + """ + return response + + def pre_delete_cached_content( + self, + request: cache_service.DeleteCachedContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + cache_service.DeleteCachedContentRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for delete_cached_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def pre_get_cached_content( + self, + request: cache_service.GetCachedContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + cache_service.GetCachedContentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_cached_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def post_get_cached_content( + self, response: cached_content.CachedContent + ) -> cached_content.CachedContent: + """Post-rpc interceptor for get_cached_content + + Override in a subclass to manipulate the response + after it is returned by the CacheService server but before + it is returned to user code. + """ + return response + + def pre_list_cached_contents( + self, + request: cache_service.ListCachedContentsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + cache_service.ListCachedContentsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_cached_contents + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def post_list_cached_contents( + self, response: cache_service.ListCachedContentsResponse + ) -> cache_service.ListCachedContentsResponse: + """Post-rpc interceptor for list_cached_contents + + Override in a subclass to manipulate the response + after it is returned by the CacheService server but before + it is returned to user code. + """ + return response + + def pre_update_cached_content( + self, + request: cache_service.UpdateCachedContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + cache_service.UpdateCachedContentRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for update_cached_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def post_update_cached_content( + self, response: gag_cached_content.CachedContent + ) -> gag_cached_content.CachedContent: + """Post-rpc interceptor for update_cached_content + + Override in a subclass to manipulate the response + after it is returned by the CacheService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the CacheService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the CacheService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the CacheService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class CacheServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: CacheServiceRestInterceptor + + +class CacheServiceRestTransport(_BaseCacheServiceRestTransport): + """REST backend synchronous transport for CacheService. + + API for managing cache of content (CachedContent resources) + that can be used in GenerativeService requests. This way + generate content requests can benefit from preprocessing work + being done earlier, possibly lowering their computational cost. + It is intended to be used with large contexts. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[CacheServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or CacheServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _CreateCachedContent( + _BaseCacheServiceRestTransport._BaseCreateCachedContent, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.CreateCachedContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: cache_service.CreateCachedContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_cached_content.CachedContent: + r"""Call the create cached content method over HTTP. + + Args: + request (~.cache_service.CreateCachedContentRequest): + The request object. Request to create CachedContent. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.gag_cached_content.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseCreateCachedContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_cached_content( + request, metadata + ) + transcoded_request = _BaseCacheServiceRestTransport._BaseCreateCachedContent._get_transcoded_request( + http_options, request + ) + + body = _BaseCacheServiceRestTransport._BaseCreateCachedContent._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseCacheServiceRestTransport._BaseCreateCachedContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.CreateCachedContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "CreateCachedContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._CreateCachedContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_cached_content.CachedContent() + pb_resp = gag_cached_content.CachedContent.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_cached_content(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = gag_cached_content.CachedContent.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.CacheServiceClient.create_cached_content", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "CreateCachedContent", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _DeleteCachedContent( + _BaseCacheServiceRestTransport._BaseDeleteCachedContent, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.DeleteCachedContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: cache_service.DeleteCachedContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete cached content method over HTTP. + + Args: + request (~.cache_service.DeleteCachedContentRequest): + The request object. Request to delete CachedContent. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseDeleteCachedContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_cached_content( + request, metadata + ) + transcoded_request = _BaseCacheServiceRestTransport._BaseDeleteCachedContent._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseCacheServiceRestTransport._BaseDeleteCachedContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.DeleteCachedContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "DeleteCachedContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._DeleteCachedContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetCachedContent( + _BaseCacheServiceRestTransport._BaseGetCachedContent, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.GetCachedContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: cache_service.GetCachedContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> cached_content.CachedContent: + r"""Call the get cached content method over HTTP. + + Args: + request (~.cache_service.GetCachedContentRequest): + The request object. Request to read CachedContent. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.cached_content.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseGetCachedContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_cached_content( + request, metadata + ) + transcoded_request = _BaseCacheServiceRestTransport._BaseGetCachedContent._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseCacheServiceRestTransport._BaseGetCachedContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.GetCachedContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "GetCachedContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._GetCachedContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = cached_content.CachedContent() + pb_resp = cached_content.CachedContent.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_cached_content(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = cached_content.CachedContent.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.CacheServiceClient.get_cached_content", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "GetCachedContent", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListCachedContents( + _BaseCacheServiceRestTransport._BaseListCachedContents, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.ListCachedContents") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: cache_service.ListCachedContentsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> cache_service.ListCachedContentsResponse: + r"""Call the list cached contents method over HTTP. + + Args: + request (~.cache_service.ListCachedContentsRequest): + The request object. Request to list CachedContents. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.cache_service.ListCachedContentsResponse: + Response with CachedContents list. + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseListCachedContents._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_cached_contents( + request, metadata + ) + transcoded_request = _BaseCacheServiceRestTransport._BaseListCachedContents._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseCacheServiceRestTransport._BaseListCachedContents._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.ListCachedContents", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "ListCachedContents", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._ListCachedContents._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = cache_service.ListCachedContentsResponse() + pb_resp = cache_service.ListCachedContentsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_cached_contents(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = cache_service.ListCachedContentsResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.CacheServiceClient.list_cached_contents", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "ListCachedContents", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _UpdateCachedContent( + _BaseCacheServiceRestTransport._BaseUpdateCachedContent, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.UpdateCachedContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: cache_service.UpdateCachedContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_cached_content.CachedContent: + r"""Call the update cached content method over HTTP. + + Args: + request (~.cache_service.UpdateCachedContentRequest): + The request object. Request to update CachedContent. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.gag_cached_content.CachedContent: + Content that has been preprocessed + and can be used in subsequent request to + GenerativeService. + + Cached content can be only used with + model it was created for. + + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseUpdateCachedContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_cached_content( + request, metadata + ) + transcoded_request = _BaseCacheServiceRestTransport._BaseUpdateCachedContent._get_transcoded_request( + http_options, request + ) + + body = _BaseCacheServiceRestTransport._BaseUpdateCachedContent._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseCacheServiceRestTransport._BaseUpdateCachedContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.UpdateCachedContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "UpdateCachedContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._UpdateCachedContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_cached_content.CachedContent() + pb_resp = gag_cached_content.CachedContent.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_cached_content(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = gag_cached_content.CachedContent.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.CacheServiceClient.update_cached_content", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "UpdateCachedContent", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def create_cached_content( + self, + ) -> Callable[ + [cache_service.CreateCachedContentRequest], gag_cached_content.CachedContent + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateCachedContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_cached_content( + self, + ) -> Callable[[cache_service.DeleteCachedContentRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteCachedContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_cached_content( + self, + ) -> Callable[ + [cache_service.GetCachedContentRequest], cached_content.CachedContent + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetCachedContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_cached_contents( + self, + ) -> Callable[ + [cache_service.ListCachedContentsRequest], + cache_service.ListCachedContentsResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListCachedContents(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_cached_content( + self, + ) -> Callable[ + [cache_service.UpdateCachedContentRequest], gag_cached_content.CachedContent + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateCachedContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseCacheServiceRestTransport._BaseGetOperation, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BaseCacheServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = ( + _BaseCacheServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseCacheServiceRestTransport._BaseListOperations, CacheServiceRestStub + ): + def __hash__(self): + return hash("CacheServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseCacheServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseCacheServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseCacheServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.CacheServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = CacheServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.CacheService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("CacheServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest_base.py new file mode 100644 index 000000000000..cb05981f3fdd --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/cache_service/transports/rest_base.py @@ -0,0 +1,399 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content + +from .base import DEFAULT_CLIENT_INFO, CacheServiceTransport + + +class _BaseCacheServiceRestTransport(CacheServiceTransport): + """Base REST backend transport for CacheService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseCreateCachedContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/cachedContents", + "body": "cached_content", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = cache_service.CreateCachedContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseCacheServiceRestTransport._BaseCreateCachedContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeleteCachedContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=cachedContents/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = cache_service.DeleteCachedContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseCacheServiceRestTransport._BaseDeleteCachedContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetCachedContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=cachedContents/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = cache_service.GetCachedContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseCacheServiceRestTransport._BaseGetCachedContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListCachedContents: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/cachedContents", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = cache_service.ListCachedContentsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseUpdateCachedContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1alpha/{cached_content.name=cachedContents/*}", + "body": "cached_content", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = cache_service.UpdateCachedContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseCacheServiceRestTransport._BaseUpdateCachedContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseCacheServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/__init__.py new file mode 100644 index 000000000000..0f3a84e6ba34 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import DiscussServiceAsyncClient +from .client import DiscussServiceClient + +__all__ = ( + "DiscussServiceClient", + "DiscussServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/async_client.py new file mode 100644 index 000000000000..badd7396f4a8 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/async_client.py @@ -0,0 +1,740 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import discuss_service, safety + +from .client import DiscussServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, DiscussServiceTransport +from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class DiscussServiceAsyncClient: + """An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + """ + + _client: DiscussServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = DiscussServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = DiscussServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = DiscussServiceClient._DEFAULT_UNIVERSE + + model_path = staticmethod(DiscussServiceClient.model_path) + parse_model_path = staticmethod(DiscussServiceClient.parse_model_path) + common_billing_account_path = staticmethod( + DiscussServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + DiscussServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(DiscussServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + DiscussServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + DiscussServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + DiscussServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(DiscussServiceClient.common_project_path) + parse_common_project_path = staticmethod( + DiscussServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(DiscussServiceClient.common_location_path) + parse_common_location_path = staticmethod( + DiscussServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceAsyncClient: The constructed client. + """ + return DiscussServiceClient.from_service_account_info.__func__(DiscussServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceAsyncClient: The constructed client. + """ + return DiscussServiceClient.from_service_account_file.__func__(DiscussServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return DiscussServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> DiscussServiceTransport: + """Returns the transport used by the client instance. + + Returns: + DiscussServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = DiscussServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, DiscussServiceTransport, Callable[..., DiscussServiceTransport]] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the discuss service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,DiscussServiceTransport,Callable[..., DiscussServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DiscussServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = DiscussServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "credentialsType": None, + }, + ) + + async def generate_message( + self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> discuss_service.GenerateMessageResponse: + r"""Generates a response from the model given an input + ``MessagePrompt``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_generate_message(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_message(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GenerateMessageRequest, dict]]): + The request object. Request to generate a message + response from the model. + model (:class:`str`): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1alpha.types.MessagePrompt`): + Required. The structured textual + input given to the model as a prompt. + Given a + prompt, the model will return what it + predicts is the next message in the + discussion. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (:class:`float`): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (:class:`int`): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If + unset, this will default to ``1``. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (:class:`float`): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens + whose probability sum is at least ``top_p``. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (:class:`int`): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any( + [model, prompt, temperature, candidate_count, top_p, top_k] + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, discuss_service.GenerateMessageRequest): + request = discuss_service.GenerateMessageRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.generate_message + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def count_message_tokens( + self, + request: Optional[ + Union[discuss_service.CountMessageTokensRequest, dict] + ] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Runs a model's tokenizer on a string and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_message_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CountMessageTokensRequest, dict]]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1alpha.types.MessagePrompt`): + Required. The prompt, whose token + count is to be returned. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CountMessageTokensResponse: + A response from CountMessageTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, discuss_service.CountMessageTokensRequest): + request = discuss_service.CountMessageTokensRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.count_message_tokens + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "DiscussServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("DiscussServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/client.py new file mode 100644 index 000000000000..b68a0d415f6a --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/client.py @@ -0,0 +1,1128 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import discuss_service, safety + +from .transports.base import DEFAULT_CLIENT_INFO, DiscussServiceTransport +from .transports.grpc import DiscussServiceGrpcTransport +from .transports.grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .transports.rest import DiscussServiceRestTransport + + +class DiscussServiceClientMeta(type): + """Metaclass for the DiscussService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[DiscussServiceTransport]] + _transport_registry["grpc"] = DiscussServiceGrpcTransport + _transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport + _transport_registry["rest"] = DiscussServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[DiscussServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class DiscussServiceClient(metaclass=DiscussServiceClientMeta): + """An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + DiscussServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> DiscussServiceTransport: + """Returns the transport used by the client instance. + + Returns: + DiscussServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def model_path( + model: str, + ) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format( + model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = DiscussServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = DiscussServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = DiscussServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, DiscussServiceTransport, Callable[..., DiscussServiceTransport]] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the discuss service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,DiscussServiceTransport,Callable[..., DiscussServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the DiscussServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = DiscussServiceClient._read_environment_variables() + self._client_cert_source = DiscussServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = DiscussServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, DiscussServiceTransport) + if transport_provided: + # transport is a DiscussServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(DiscussServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or DiscussServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[DiscussServiceTransport], Callable[..., DiscussServiceTransport] + ] = ( + DiscussServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., DiscussServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.DiscussServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "credentialsType": None, + }, + ) + + def generate_message( + self, + request: Optional[Union[discuss_service.GenerateMessageRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> discuss_service.GenerateMessageResponse: + r"""Generates a response from the model given an input + ``MessagePrompt``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_generate_message(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_message(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GenerateMessageRequest, dict]): + The request object. Request to generate a message + response from the model. + model (str): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1alpha.types.MessagePrompt): + Required. The structured textual + input given to the model as a prompt. + Given a + prompt, the model will return what it + predicts is the next message in the + discussion. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (int): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If + unset, this will default to ``1``. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (float): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens + whose probability sum is at least ``top_p``. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any( + [model, prompt, temperature, candidate_count, top_p, top_k] + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, discuss_service.GenerateMessageRequest): + request = discuss_service.GenerateMessageRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_message] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def count_message_tokens( + self, + request: Optional[ + Union[discuss_service.CountMessageTokensRequest, dict] + ] = None, + *, + model: Optional[str] = None, + prompt: Optional[discuss_service.MessagePrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Runs a model's tokenizer on a string and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_message_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CountMessageTokensRequest, dict]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1alpha.types.MessagePrompt): + Required. The prompt, whose token + count is to be returned. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CountMessageTokensResponse: + A response from CountMessageTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, discuss_service.CountMessageTokensRequest): + request = discuss_service.CountMessageTokensRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.count_message_tokens] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "DiscussServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("DiscussServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/README.rst new file mode 100644 index 000000000000..864ca745c717 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`DiscussServiceTransport` is the ABC for all transports. +- public child `DiscussServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `DiscussServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseDiscussServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `DiscussServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/__init__.py new file mode 100644 index 000000000000..05b2d4522c01 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import DiscussServiceTransport +from .grpc import DiscussServiceGrpcTransport +from .grpc_asyncio import DiscussServiceGrpcAsyncIOTransport +from .rest import DiscussServiceRestInterceptor, DiscussServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[DiscussServiceTransport]] +_transport_registry["grpc"] = DiscussServiceGrpcTransport +_transport_registry["grpc_asyncio"] = DiscussServiceGrpcAsyncIOTransport +_transport_registry["rest"] = DiscussServiceRestTransport + +__all__ = ( + "DiscussServiceTransport", + "DiscussServiceGrpcTransport", + "DiscussServiceGrpcAsyncIOTransport", + "DiscussServiceRestTransport", + "DiscussServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/base.py new file mode 100644 index 000000000000..345e7b694b51 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/base.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import discuss_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class DiscussServiceTransport(abc.ABC): + """Abstract transport class for DiscussService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.generate_message: gapic_v1.method.wrap_method( + self.generate_message, + default_timeout=None, + client_info=client_info, + ), + self.count_message_tokens: gapic_v1.method.wrap_method( + self.count_message_tokens, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + Union[ + discuss_service.GenerateMessageResponse, + Awaitable[discuss_service.GenerateMessageResponse], + ], + ]: + raise NotImplementedError() + + @property + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Union[ + discuss_service.CountMessageTokensResponse, + Awaitable[discuss_service.CountMessageTokensResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("DiscussServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc.py new file mode 100644 index 000000000000..b614519801a2 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc.py @@ -0,0 +1,431 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import discuss_service + +from .base import DEFAULT_CLIENT_INFO, DiscussServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class DiscussServiceGrpcTransport(DiscussServiceTransport): + """gRPC backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse, + ]: + r"""Return a callable for the generate message method over gRPC. + + Generates a response from the model given an input + ``MessagePrompt``. + + Returns: + Callable[[~.GenerateMessageRequest], + ~.GenerateMessageResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_message" not in self._stubs: + self._stubs["generate_message"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.DiscussService/GenerateMessage", + request_serializer=discuss_service.GenerateMessageRequest.serialize, + response_deserializer=discuss_service.GenerateMessageResponse.deserialize, + ) + return self._stubs["generate_message"] + + @property + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse, + ]: + r"""Return a callable for the count message tokens method over gRPC. + + Runs a model's tokenizer on a string and returns the + token count. + + Returns: + Callable[[~.CountMessageTokensRequest], + ~.CountMessageTokensResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "count_message_tokens" not in self._stubs: + self._stubs["count_message_tokens"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.DiscussService/CountMessageTokens", + request_serializer=discuss_service.CountMessageTokensRequest.serialize, + response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, + ) + return self._stubs["count_message_tokens"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("DiscussServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..acd1b9564f63 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/grpc_asyncio.py @@ -0,0 +1,468 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import discuss_service + +from .base import DEFAULT_CLIENT_INFO, DiscussServiceTransport +from .grpc import DiscussServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class DiscussServiceGrpcAsyncIOTransport(DiscussServiceTransport): + """gRPC AsyncIO backend transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + Awaitable[discuss_service.GenerateMessageResponse], + ]: + r"""Return a callable for the generate message method over gRPC. + + Generates a response from the model given an input + ``MessagePrompt``. + + Returns: + Callable[[~.GenerateMessageRequest], + Awaitable[~.GenerateMessageResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_message" not in self._stubs: + self._stubs["generate_message"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.DiscussService/GenerateMessage", + request_serializer=discuss_service.GenerateMessageRequest.serialize, + response_deserializer=discuss_service.GenerateMessageResponse.deserialize, + ) + return self._stubs["generate_message"] + + @property + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + Awaitable[discuss_service.CountMessageTokensResponse], + ]: + r"""Return a callable for the count message tokens method over gRPC. + + Runs a model's tokenizer on a string and returns the + token count. + + Returns: + Callable[[~.CountMessageTokensRequest], + Awaitable[~.CountMessageTokensResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "count_message_tokens" not in self._stubs: + self._stubs["count_message_tokens"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.DiscussService/CountMessageTokens", + request_serializer=discuss_service.CountMessageTokensRequest.serialize, + response_deserializer=discuss_service.CountMessageTokensResponse.deserialize, + ) + return self._stubs["count_message_tokens"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.generate_message: self._wrap_method( + self.generate_message, + default_timeout=None, + client_info=client_info, + ), + self.count_message_tokens: self._wrap_method( + self.count_message_tokens, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("DiscussServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest.py new file mode 100644 index 000000000000..f8b214dd0ac0 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest.py @@ -0,0 +1,909 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import discuss_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseDiscussServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class DiscussServiceRestInterceptor: + """Interceptor for DiscussService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the DiscussServiceRestTransport. + + .. code-block:: python + class MyCustomDiscussServiceInterceptor(DiscussServiceRestInterceptor): + def pre_count_message_tokens(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_count_message_tokens(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_message(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_message(self, response): + logging.log(f"Received response: {response}") + return response + + transport = DiscussServiceRestTransport(interceptor=MyCustomDiscussServiceInterceptor()) + client = DiscussServiceClient(transport=transport) + + + """ + + def pre_count_message_tokens( + self, + request: discuss_service.CountMessageTokensRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + discuss_service.CountMessageTokensRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for count_message_tokens + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_count_message_tokens( + self, response: discuss_service.CountMessageTokensResponse + ) -> discuss_service.CountMessageTokensResponse: + """Post-rpc interceptor for count_message_tokens + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. + """ + return response + + def pre_generate_message( + self, + request: discuss_service.GenerateMessageRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + discuss_service.GenerateMessageRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for generate_message + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_generate_message( + self, response: discuss_service.GenerateMessageResponse + ) -> discuss_service.GenerateMessageResponse: + """Post-rpc interceptor for generate_message + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the DiscussService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the DiscussService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class DiscussServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: DiscussServiceRestInterceptor + + +class DiscussServiceRestTransport(_BaseDiscussServiceRestTransport): + """REST backend synchronous transport for DiscussService. + + An API for using Generative Language Models (GLMs) in dialog + applications. + Also known as large language models (LLMs), this API provides + models that are trained for multi-turn dialog. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[DiscussServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or DiscussServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _CountMessageTokens( + _BaseDiscussServiceRestTransport._BaseCountMessageTokens, DiscussServiceRestStub + ): + def __hash__(self): + return hash("DiscussServiceRestTransport.CountMessageTokens") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: discuss_service.CountMessageTokensRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> discuss_service.CountMessageTokensResponse: + r"""Call the count message tokens method over HTTP. + + Args: + request (~.discuss_service.CountMessageTokensRequest): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.discuss_service.CountMessageTokensResponse: + A response from ``CountMessageTokens``. + + It returns the model's ``token_count`` for the + ``prompt``. + + """ + + http_options = ( + _BaseDiscussServiceRestTransport._BaseCountMessageTokens._get_http_options() + ) + + request, metadata = self._interceptor.pre_count_message_tokens( + request, metadata + ) + transcoded_request = _BaseDiscussServiceRestTransport._BaseCountMessageTokens._get_transcoded_request( + http_options, request + ) + + body = _BaseDiscussServiceRestTransport._BaseCountMessageTokens._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseDiscussServiceRestTransport._BaseCountMessageTokens._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.DiscussServiceClient.CountMessageTokens", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "CountMessageTokens", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = DiscussServiceRestTransport._CountMessageTokens._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = discuss_service.CountMessageTokensResponse() + pb_resp = discuss_service.CountMessageTokensResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_count_message_tokens(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + discuss_service.CountMessageTokensResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.DiscussServiceClient.count_message_tokens", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "CountMessageTokens", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GenerateMessage( + _BaseDiscussServiceRestTransport._BaseGenerateMessage, DiscussServiceRestStub + ): + def __hash__(self): + return hash("DiscussServiceRestTransport.GenerateMessage") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: discuss_service.GenerateMessageRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> discuss_service.GenerateMessageResponse: + r"""Call the generate message method over HTTP. + + Args: + request (~.discuss_service.GenerateMessageRequest): + The request object. Request to generate a message + response from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.discuss_service.GenerateMessageResponse: + The response from the model. + + This includes candidate messages and + conversation history in the form of + chronologically-ordered messages. + + """ + + http_options = ( + _BaseDiscussServiceRestTransport._BaseGenerateMessage._get_http_options() + ) + + request, metadata = self._interceptor.pre_generate_message( + request, metadata + ) + transcoded_request = _BaseDiscussServiceRestTransport._BaseGenerateMessage._get_transcoded_request( + http_options, request + ) + + body = _BaseDiscussServiceRestTransport._BaseGenerateMessage._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseDiscussServiceRestTransport._BaseGenerateMessage._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.DiscussServiceClient.GenerateMessage", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "GenerateMessage", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = DiscussServiceRestTransport._GenerateMessage._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = discuss_service.GenerateMessageResponse() + pb_resp = discuss_service.GenerateMessageResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_generate_message(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = discuss_service.GenerateMessageResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.DiscussServiceClient.generate_message", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "GenerateMessage", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def count_message_tokens( + self, + ) -> Callable[ + [discuss_service.CountMessageTokensRequest], + discuss_service.CountMessageTokensResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CountMessageTokens(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_message( + self, + ) -> Callable[ + [discuss_service.GenerateMessageRequest], + discuss_service.GenerateMessageResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateMessage(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseDiscussServiceRestTransport._BaseGetOperation, DiscussServiceRestStub + ): + def __hash__(self): + return hash("DiscussServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseDiscussServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BaseDiscussServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseDiscussServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.DiscussServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = DiscussServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseDiscussServiceRestTransport._BaseListOperations, DiscussServiceRestStub + ): + def __hash__(self): + return hash("DiscussServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseDiscussServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseDiscussServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseDiscussServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.DiscussServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = DiscussServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.DiscussService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("DiscussServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest_base.py new file mode 100644 index 000000000000..863221cd4479 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/discuss_service/transports/rest_base.py @@ -0,0 +1,268 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import discuss_service + +from .base import DEFAULT_CLIENT_INFO, DiscussServiceTransport + + +class _BaseDiscussServiceRestTransport(DiscussServiceTransport): + """Base REST backend transport for DiscussService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseCountMessageTokens: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:countMessageTokens", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = discuss_service.CountMessageTokensRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseDiscussServiceRestTransport._BaseCountMessageTokens._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGenerateMessage: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:generateMessage", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = discuss_service.GenerateMessageRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseDiscussServiceRestTransport._BaseGenerateMessage._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseDiscussServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/__init__.py new file mode 100644 index 000000000000..ccbe9be60abc --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import FileServiceAsyncClient +from .client import FileServiceClient + +__all__ = ( + "FileServiceClient", + "FileServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/async_client.py new file mode 100644 index 000000000000..2fb70165a176 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/async_client.py @@ -0,0 +1,780 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.file_service import pagers +from google.ai.generativelanguage_v1alpha.types import file, file_service + +from .client import FileServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, FileServiceTransport +from .transports.grpc_asyncio import FileServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class FileServiceAsyncClient: + """An API for uploading and managing files.""" + + _client: FileServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = FileServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = FileServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = FileServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = FileServiceClient._DEFAULT_UNIVERSE + + file_path = staticmethod(FileServiceClient.file_path) + parse_file_path = staticmethod(FileServiceClient.parse_file_path) + common_billing_account_path = staticmethod( + FileServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + FileServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(FileServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(FileServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(FileServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + FileServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(FileServiceClient.common_project_path) + parse_common_project_path = staticmethod( + FileServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(FileServiceClient.common_location_path) + parse_common_location_path = staticmethod( + FileServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FileServiceAsyncClient: The constructed client. + """ + return FileServiceClient.from_service_account_info.__func__(FileServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FileServiceAsyncClient: The constructed client. + """ + return FileServiceClient.from_service_account_file.__func__(FileServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return FileServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> FileServiceTransport: + """Returns the transport used by the client instance. + + Returns: + FileServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = FileServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, FileServiceTransport, Callable[..., FileServiceTransport]] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the file service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,FileServiceTransport,Callable[..., FileServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FileServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = FileServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.FileServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "credentialsType": None, + }, + ) + + async def create_file( + self, + request: Optional[Union[file_service.CreateFileRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file_service.CreateFileResponse: + r"""Creates a ``File``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateFileRequest( + ) + + # Make the request + response = await client.create_file(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreateFileRequest, dict]]): + The request object. Request for ``CreateFile``. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CreateFileResponse: + Response for CreateFile. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.CreateFileRequest): + request = file_service.CreateFileRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_file + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_files( + self, + request: Optional[Union[file_service.ListFilesRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListFilesAsyncPager: + r"""Lists the metadata for ``File``\ s owned by the requesting + project. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_files(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListFilesRequest( + ) + + # Make the request + page_result = client.list_files(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListFilesRequest, dict]]): + The request object. Request for ``ListFiles``. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.file_service.pagers.ListFilesAsyncPager: + Response for ListFiles. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.ListFilesRequest): + request = file_service.ListFilesRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_files + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListFilesAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_file( + self, + request: Optional[Union[file_service.GetFileRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file.File: + r"""Gets the metadata for the given ``File``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetFileRequest( + name="name_value", + ) + + # Make the request + response = await client.get_file(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetFileRequest, dict]]): + The request object. Request for ``GetFile``. + name (:class:`str`): + Required. The name of the ``File`` to get. Example: + ``files/abc-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.File: + A file uploaded to the API. + Next ID: 15 + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.GetFileRequest): + request = file_service.GetFileRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[self._client._transport.get_file] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_file( + self, + request: Optional[Union[file_service.DeleteFileRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes the ``File``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteFileRequest( + name="name_value", + ) + + # Make the request + await client.delete_file(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeleteFileRequest, dict]]): + The request object. Request for ``DeleteFile``. + name (:class:`str`): + Required. The name of the ``File`` to delete. Example: + ``files/abc-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.DeleteFileRequest): + request = file_service.DeleteFileRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_file + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "FileServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("FileServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/client.py new file mode 100644 index 000000000000..84e6918cdba4 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/client.py @@ -0,0 +1,1165 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.file_service import pagers +from google.ai.generativelanguage_v1alpha.types import file, file_service + +from .transports.base import DEFAULT_CLIENT_INFO, FileServiceTransport +from .transports.grpc import FileServiceGrpcTransport +from .transports.grpc_asyncio import FileServiceGrpcAsyncIOTransport +from .transports.rest import FileServiceRestTransport + + +class FileServiceClientMeta(type): + """Metaclass for the FileService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[FileServiceTransport]] + _transport_registry["grpc"] = FileServiceGrpcTransport + _transport_registry["grpc_asyncio"] = FileServiceGrpcAsyncIOTransport + _transport_registry["rest"] = FileServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[FileServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class FileServiceClient(metaclass=FileServiceClientMeta): + """An API for uploading and managing files.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FileServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + FileServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> FileServiceTransport: + """Returns the transport used by the client instance. + + Returns: + FileServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def file_path( + file: str, + ) -> str: + """Returns a fully-qualified file string.""" + return "files/{file}".format( + file=file, + ) + + @staticmethod + def parse_file_path(path: str) -> Dict[str, str]: + """Parses a file path into its component segments.""" + m = re.match(r"^files/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = FileServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = FileServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = FileServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = FileServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, FileServiceTransport, Callable[..., FileServiceTransport]] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the file service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,FileServiceTransport,Callable[..., FileServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the FileServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = FileServiceClient._read_environment_variables() + self._client_cert_source = FileServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = FileServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, FileServiceTransport) + if transport_provided: + # transport is a FileServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(FileServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = self._api_endpoint or FileServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[FileServiceTransport], Callable[..., FileServiceTransport] + ] = ( + FileServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., FileServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.FileServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "credentialsType": None, + }, + ) + + def create_file( + self, + request: Optional[Union[file_service.CreateFileRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file_service.CreateFileResponse: + r"""Creates a ``File``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateFileRequest( + ) + + # Make the request + response = client.create_file(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreateFileRequest, dict]): + The request object. Request for ``CreateFile``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CreateFileResponse: + Response for CreateFile. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.CreateFileRequest): + request = file_service.CreateFileRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_file] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_files( + self, + request: Optional[Union[file_service.ListFilesRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListFilesPager: + r"""Lists the metadata for ``File``\ s owned by the requesting + project. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_files(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListFilesRequest( + ) + + # Make the request + page_result = client.list_files(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListFilesRequest, dict]): + The request object. Request for ``ListFiles``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.file_service.pagers.ListFilesPager: + Response for ListFiles. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.ListFilesRequest): + request = file_service.ListFilesRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_files] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListFilesPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_file( + self, + request: Optional[Union[file_service.GetFileRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file.File: + r"""Gets the metadata for the given ``File``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetFileRequest( + name="name_value", + ) + + # Make the request + response = client.get_file(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetFileRequest, dict]): + The request object. Request for ``GetFile``. + name (str): + Required. The name of the ``File`` to get. Example: + ``files/abc-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.File: + A file uploaded to the API. + Next ID: 15 + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.GetFileRequest): + request = file_service.GetFileRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_file] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_file( + self, + request: Optional[Union[file_service.DeleteFileRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes the ``File``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteFileRequest( + name="name_value", + ) + + # Make the request + client.delete_file(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeleteFileRequest, dict]): + The request object. Request for ``DeleteFile``. + name (str): + Required. The name of the ``File`` to delete. Example: + ``files/abc-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, file_service.DeleteFileRequest): + request = file_service.DeleteFileRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_file] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def __enter__(self) -> "FileServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("FileServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/pagers.py new file mode 100644 index 000000000000..adff51f2d9a9 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/pagers.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) + +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + +from google.ai.generativelanguage_v1alpha.types import file, file_service + + +class ListFilesPager: + """A pager for iterating through ``list_files`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListFilesResponse` object, and + provides an ``__iter__`` method to iterate through its + ``files`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListFiles`` requests and continue to iterate + through the ``files`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListFilesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., file_service.ListFilesResponse], + request: file_service.ListFilesRequest, + response: file_service.ListFilesResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListFilesRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListFilesResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = file_service.ListFilesRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[file_service.ListFilesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[file.File]: + for page in self.pages: + yield from page.files + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListFilesAsyncPager: + """A pager for iterating through ``list_files`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListFilesResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``files`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListFiles`` requests and continue to iterate + through the ``files`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListFilesResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[file_service.ListFilesResponse]], + request: file_service.ListFilesRequest, + response: file_service.ListFilesResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListFilesRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListFilesResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = file_service.ListFilesRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[file_service.ListFilesResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[file.File]: + async def async_generator(): + async for page in self.pages: + for response in page.files: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/README.rst new file mode 100644 index 000000000000..f61177d3c2c0 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`FileServiceTransport` is the ABC for all transports. +- public child `FileServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `FileServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseFileServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `FileServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/__init__.py new file mode 100644 index 000000000000..4ea65f1e94d8 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import FileServiceTransport +from .grpc import FileServiceGrpcTransport +from .grpc_asyncio import FileServiceGrpcAsyncIOTransport +from .rest import FileServiceRestInterceptor, FileServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[FileServiceTransport]] +_transport_registry["grpc"] = FileServiceGrpcTransport +_transport_registry["grpc_asyncio"] = FileServiceGrpcAsyncIOTransport +_transport_registry["rest"] = FileServiceRestTransport + +__all__ = ( + "FileServiceTransport", + "FileServiceGrpcTransport", + "FileServiceGrpcAsyncIOTransport", + "FileServiceRestTransport", + "FileServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/base.py new file mode 100644 index 000000000000..6730b2082aef --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/base.py @@ -0,0 +1,239 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import file, file_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class FileServiceTransport(abc.ABC): + """Abstract transport class for FileService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_file: gapic_v1.method.wrap_method( + self.create_file, + default_timeout=None, + client_info=client_info, + ), + self.list_files: gapic_v1.method.wrap_method( + self.list_files, + default_timeout=None, + client_info=client_info, + ), + self.get_file: gapic_v1.method.wrap_method( + self.get_file, + default_timeout=None, + client_info=client_info, + ), + self.delete_file: gapic_v1.method.wrap_method( + self.delete_file, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def create_file( + self, + ) -> Callable[ + [file_service.CreateFileRequest], + Union[ + file_service.CreateFileResponse, Awaitable[file_service.CreateFileResponse] + ], + ]: + raise NotImplementedError() + + @property + def list_files( + self, + ) -> Callable[ + [file_service.ListFilesRequest], + Union[ + file_service.ListFilesResponse, Awaitable[file_service.ListFilesResponse] + ], + ]: + raise NotImplementedError() + + @property + def get_file( + self, + ) -> Callable[ + [file_service.GetFileRequest], Union[file.File, Awaitable[file.File]] + ]: + raise NotImplementedError() + + @property + def delete_file( + self, + ) -> Callable[ + [file_service.DeleteFileRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("FileServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc.py new file mode 100644 index 000000000000..46de87d6157e --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc.py @@ -0,0 +1,472 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import file, file_service + +from .base import DEFAULT_CLIENT_INFO, FileServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class FileServiceGrpcTransport(FileServiceTransport): + """gRPC backend transport for FileService. + + An API for uploading and managing files. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def create_file( + self, + ) -> Callable[[file_service.CreateFileRequest], file_service.CreateFileResponse]: + r"""Return a callable for the create file method over gRPC. + + Creates a ``File``. + + Returns: + Callable[[~.CreateFileRequest], + ~.CreateFileResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_file" not in self._stubs: + self._stubs["create_file"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/CreateFile", + request_serializer=file_service.CreateFileRequest.serialize, + response_deserializer=file_service.CreateFileResponse.deserialize, + ) + return self._stubs["create_file"] + + @property + def list_files( + self, + ) -> Callable[[file_service.ListFilesRequest], file_service.ListFilesResponse]: + r"""Return a callable for the list files method over gRPC. + + Lists the metadata for ``File``\ s owned by the requesting + project. + + Returns: + Callable[[~.ListFilesRequest], + ~.ListFilesResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_files" not in self._stubs: + self._stubs["list_files"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/ListFiles", + request_serializer=file_service.ListFilesRequest.serialize, + response_deserializer=file_service.ListFilesResponse.deserialize, + ) + return self._stubs["list_files"] + + @property + def get_file(self) -> Callable[[file_service.GetFileRequest], file.File]: + r"""Return a callable for the get file method over gRPC. + + Gets the metadata for the given ``File``. + + Returns: + Callable[[~.GetFileRequest], + ~.File]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_file" not in self._stubs: + self._stubs["get_file"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/GetFile", + request_serializer=file_service.GetFileRequest.serialize, + response_deserializer=file.File.deserialize, + ) + return self._stubs["get_file"] + + @property + def delete_file( + self, + ) -> Callable[[file_service.DeleteFileRequest], empty_pb2.Empty]: + r"""Return a callable for the delete file method over gRPC. + + Deletes the ``File``. + + Returns: + Callable[[~.DeleteFileRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_file" not in self._stubs: + self._stubs["delete_file"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/DeleteFile", + request_serializer=file_service.DeleteFileRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_file"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("FileServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..c1d9aa30ce58 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/grpc_asyncio.py @@ -0,0 +1,523 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import file, file_service + +from .base import DEFAULT_CLIENT_INFO, FileServiceTransport +from .grpc import FileServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class FileServiceGrpcAsyncIOTransport(FileServiceTransport): + """gRPC AsyncIO backend transport for FileService. + + An API for uploading and managing files. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def create_file( + self, + ) -> Callable[ + [file_service.CreateFileRequest], Awaitable[file_service.CreateFileResponse] + ]: + r"""Return a callable for the create file method over gRPC. + + Creates a ``File``. + + Returns: + Callable[[~.CreateFileRequest], + Awaitable[~.CreateFileResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_file" not in self._stubs: + self._stubs["create_file"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/CreateFile", + request_serializer=file_service.CreateFileRequest.serialize, + response_deserializer=file_service.CreateFileResponse.deserialize, + ) + return self._stubs["create_file"] + + @property + def list_files( + self, + ) -> Callable[ + [file_service.ListFilesRequest], Awaitable[file_service.ListFilesResponse] + ]: + r"""Return a callable for the list files method over gRPC. + + Lists the metadata for ``File``\ s owned by the requesting + project. + + Returns: + Callable[[~.ListFilesRequest], + Awaitable[~.ListFilesResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_files" not in self._stubs: + self._stubs["list_files"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/ListFiles", + request_serializer=file_service.ListFilesRequest.serialize, + response_deserializer=file_service.ListFilesResponse.deserialize, + ) + return self._stubs["list_files"] + + @property + def get_file(self) -> Callable[[file_service.GetFileRequest], Awaitable[file.File]]: + r"""Return a callable for the get file method over gRPC. + + Gets the metadata for the given ``File``. + + Returns: + Callable[[~.GetFileRequest], + Awaitable[~.File]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_file" not in self._stubs: + self._stubs["get_file"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/GetFile", + request_serializer=file_service.GetFileRequest.serialize, + response_deserializer=file.File.deserialize, + ) + return self._stubs["get_file"] + + @property + def delete_file( + self, + ) -> Callable[[file_service.DeleteFileRequest], Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the delete file method over gRPC. + + Deletes the ``File``. + + Returns: + Callable[[~.DeleteFileRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_file" not in self._stubs: + self._stubs["delete_file"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.FileService/DeleteFile", + request_serializer=file_service.DeleteFileRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_file"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_file: self._wrap_method( + self.create_file, + default_timeout=None, + client_info=client_info, + ), + self.list_files: self._wrap_method( + self.list_files, + default_timeout=None, + client_info=client_info, + ), + self.get_file: self._wrap_method( + self.get_file, + default_timeout=None, + client_info=client_info, + ), + self.delete_file: self._wrap_method( + self.delete_file, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("FileServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest.py new file mode 100644 index 000000000000..4c47868cd4ca --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest.py @@ -0,0 +1,1191 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import file, file_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseFileServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class FileServiceRestInterceptor: + """Interceptor for FileService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the FileServiceRestTransport. + + .. code-block:: python + class MyCustomFileServiceInterceptor(FileServiceRestInterceptor): + def pre_create_file(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_file(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_file(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_file(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_file(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_files(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_files(self, response): + logging.log(f"Received response: {response}") + return response + + transport = FileServiceRestTransport(interceptor=MyCustomFileServiceInterceptor()) + client = FileServiceClient(transport=transport) + + + """ + + def pre_create_file( + self, + request: file_service.CreateFileRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[file_service.CreateFileRequest, Sequence[Tuple[str, Union[str, bytes]]]]: + """Pre-rpc interceptor for create_file + + Override in a subclass to manipulate the request or metadata + before they are sent to the FileService server. + """ + return request, metadata + + def post_create_file( + self, response: file_service.CreateFileResponse + ) -> file_service.CreateFileResponse: + """Post-rpc interceptor for create_file + + Override in a subclass to manipulate the response + after it is returned by the FileService server but before + it is returned to user code. + """ + return response + + def pre_delete_file( + self, + request: file_service.DeleteFileRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[file_service.DeleteFileRequest, Sequence[Tuple[str, Union[str, bytes]]]]: + """Pre-rpc interceptor for delete_file + + Override in a subclass to manipulate the request or metadata + before they are sent to the FileService server. + """ + return request, metadata + + def pre_get_file( + self, + request: file_service.GetFileRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[file_service.GetFileRequest, Sequence[Tuple[str, Union[str, bytes]]]]: + """Pre-rpc interceptor for get_file + + Override in a subclass to manipulate the request or metadata + before they are sent to the FileService server. + """ + return request, metadata + + def post_get_file(self, response: file.File) -> file.File: + """Post-rpc interceptor for get_file + + Override in a subclass to manipulate the response + after it is returned by the FileService server but before + it is returned to user code. + """ + return response + + def pre_list_files( + self, + request: file_service.ListFilesRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[file_service.ListFilesRequest, Sequence[Tuple[str, Union[str, bytes]]]]: + """Pre-rpc interceptor for list_files + + Override in a subclass to manipulate the request or metadata + before they are sent to the FileService server. + """ + return request, metadata + + def post_list_files( + self, response: file_service.ListFilesResponse + ) -> file_service.ListFilesResponse: + """Post-rpc interceptor for list_files + + Override in a subclass to manipulate the response + after it is returned by the FileService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the FileService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the FileService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the FileService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the FileService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class FileServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: FileServiceRestInterceptor + + +class FileServiceRestTransport(_BaseFileServiceRestTransport): + """REST backend synchronous transport for FileService. + + An API for uploading and managing files. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[FileServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or FileServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _CreateFile( + _BaseFileServiceRestTransport._BaseCreateFile, FileServiceRestStub + ): + def __hash__(self): + return hash("FileServiceRestTransport.CreateFile") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: file_service.CreateFileRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file_service.CreateFileResponse: + r"""Call the create file method over HTTP. + + Args: + request (~.file_service.CreateFileRequest): + The request object. Request for ``CreateFile``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.file_service.CreateFileResponse: + Response for ``CreateFile``. + """ + + http_options = ( + _BaseFileServiceRestTransport._BaseCreateFile._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_file(request, metadata) + transcoded_request = ( + _BaseFileServiceRestTransport._BaseCreateFile._get_transcoded_request( + http_options, request + ) + ) + + body = _BaseFileServiceRestTransport._BaseCreateFile._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = ( + _BaseFileServiceRestTransport._BaseCreateFile._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.FileServiceClient.CreateFile", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "CreateFile", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FileServiceRestTransport._CreateFile._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = file_service.CreateFileResponse() + pb_resp = file_service.CreateFileResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_file(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = file_service.CreateFileResponse.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.FileServiceClient.create_file", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "CreateFile", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _DeleteFile( + _BaseFileServiceRestTransport._BaseDeleteFile, FileServiceRestStub + ): + def __hash__(self): + return hash("FileServiceRestTransport.DeleteFile") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: file_service.DeleteFileRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete file method over HTTP. + + Args: + request (~.file_service.DeleteFileRequest): + The request object. Request for ``DeleteFile``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseFileServiceRestTransport._BaseDeleteFile._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_file(request, metadata) + transcoded_request = ( + _BaseFileServiceRestTransport._BaseDeleteFile._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFileServiceRestTransport._BaseDeleteFile._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.FileServiceClient.DeleteFile", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "DeleteFile", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FileServiceRestTransport._DeleteFile._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetFile(_BaseFileServiceRestTransport._BaseGetFile, FileServiceRestStub): + def __hash__(self): + return hash("FileServiceRestTransport.GetFile") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: file_service.GetFileRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file.File: + r"""Call the get file method over HTTP. + + Args: + request (~.file_service.GetFileRequest): + The request object. Request for ``GetFile``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.file.File: + A file uploaded to the API. + Next ID: 15 + + """ + + http_options = ( + _BaseFileServiceRestTransport._BaseGetFile._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_file(request, metadata) + transcoded_request = ( + _BaseFileServiceRestTransport._BaseGetFile._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFileServiceRestTransport._BaseGetFile._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.FileServiceClient.GetFile", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "GetFile", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FileServiceRestTransport._GetFile._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = file.File() + pb_resp = file.File.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_file(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = file.File.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.FileServiceClient.get_file", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "GetFile", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListFiles(_BaseFileServiceRestTransport._BaseListFiles, FileServiceRestStub): + def __hash__(self): + return hash("FileServiceRestTransport.ListFiles") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: file_service.ListFilesRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> file_service.ListFilesResponse: + r"""Call the list files method over HTTP. + + Args: + request (~.file_service.ListFilesRequest): + The request object. Request for ``ListFiles``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.file_service.ListFilesResponse: + Response for ``ListFiles``. + """ + + http_options = ( + _BaseFileServiceRestTransport._BaseListFiles._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_files(request, metadata) + transcoded_request = ( + _BaseFileServiceRestTransport._BaseListFiles._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFileServiceRestTransport._BaseListFiles._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.FileServiceClient.ListFiles", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "ListFiles", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FileServiceRestTransport._ListFiles._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = file_service.ListFilesResponse() + pb_resp = file_service.ListFilesResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_files(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = file_service.ListFilesResponse.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.FileServiceClient.list_files", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "ListFiles", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def create_file( + self, + ) -> Callable[[file_service.CreateFileRequest], file_service.CreateFileResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateFile(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_file( + self, + ) -> Callable[[file_service.DeleteFileRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteFile(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_file(self) -> Callable[[file_service.GetFileRequest], file.File]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetFile(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_files( + self, + ) -> Callable[[file_service.ListFilesRequest], file_service.ListFilesResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListFiles(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseFileServiceRestTransport._BaseGetOperation, FileServiceRestStub + ): + def __hash__(self): + return hash("FileServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseFileServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = ( + _BaseFileServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseFileServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.FileServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FileServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.FileServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseFileServiceRestTransport._BaseListOperations, FileServiceRestStub + ): + def __hash__(self): + return hash("FileServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseFileServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseFileServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseFileServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.FileServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = FileServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.FileServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.FileService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("FileServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest_base.py new file mode 100644 index 000000000000..0d97db6b8db0 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/file_service/transports/rest_base.py @@ -0,0 +1,323 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import file, file_service + +from .base import DEFAULT_CLIENT_INFO, FileServiceTransport + + +class _BaseFileServiceRestTransport(FileServiceTransport): + """Base REST backend transport for FileService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseCreateFile: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/files", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = file_service.CreateFileRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeleteFile: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=files/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = file_service.DeleteFileRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseFileServiceRestTransport._BaseDeleteFile._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetFile: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=files/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = file_service.GetFileRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseFileServiceRestTransport._BaseGetFile._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListFiles: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/files", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = file_service.ListFilesRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseFileServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/__init__.py new file mode 100644 index 000000000000..221b8e46bdaa --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import GenerativeServiceAsyncClient +from .client import GenerativeServiceClient + +__all__ = ( + "GenerativeServiceClient", + "GenerativeServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/async_client.py new file mode 100644 index 000000000000..863c225af329 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/async_client.py @@ -0,0 +1,1366 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import generative_service, safety +from google.ai.generativelanguage_v1alpha.types import content +from google.ai.generativelanguage_v1alpha.types import content as gag_content + +from .client import GenerativeServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, GenerativeServiceTransport +from .transports.grpc_asyncio import GenerativeServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class GenerativeServiceAsyncClient: + """API for using Large Models that generate multimodal content + and have additional capabilities beyond text generation. + """ + + _client: GenerativeServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = GenerativeServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = GenerativeServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = GenerativeServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = GenerativeServiceClient._DEFAULT_UNIVERSE + + cached_content_path = staticmethod(GenerativeServiceClient.cached_content_path) + parse_cached_content_path = staticmethod( + GenerativeServiceClient.parse_cached_content_path + ) + model_path = staticmethod(GenerativeServiceClient.model_path) + parse_model_path = staticmethod(GenerativeServiceClient.parse_model_path) + common_billing_account_path = staticmethod( + GenerativeServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + GenerativeServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(GenerativeServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + GenerativeServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + GenerativeServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + GenerativeServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(GenerativeServiceClient.common_project_path) + parse_common_project_path = staticmethod( + GenerativeServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(GenerativeServiceClient.common_location_path) + parse_common_location_path = staticmethod( + GenerativeServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + GenerativeServiceAsyncClient: The constructed client. + """ + return GenerativeServiceClient.from_service_account_info.__func__(GenerativeServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + GenerativeServiceAsyncClient: The constructed client. + """ + return GenerativeServiceClient.from_service_account_file.__func__(GenerativeServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return GenerativeServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> GenerativeServiceTransport: + """Returns the transport used by the client instance. + + Returns: + GenerativeServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = GenerativeServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + GenerativeServiceTransport, + Callable[..., GenerativeServiceTransport], + ] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the generative service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,GenerativeServiceTransport,Callable[..., GenerativeServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the GenerativeServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = GenerativeServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "credentialsType": None, + }, + ) + + async def generate_content( + self, + request: Optional[ + Union[generative_service.GenerateContentRequest, dict] + ] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.GenerateContentResponse: + r"""Generates a model response given an input + ``GenerateContentRequest``. Refer to the `text generation + guide `__ + for detailed usage information. Input capabilities differ + between models, including tuned models. Refer to the `model + guide `__ + and `tuning + guide `__ + for details. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + response = await client.generate_content(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GenerateContentRequest, dict]]): + The request object. Request to generate a completion from + the model. + model (:class:`str`): + Required. The name of the ``Model`` to use for + generating the completion. + + Format: ``models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (:class:`MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]`): + Required. The content of the current conversation with + the model. + + For single-turn queries, this is a single instance. For + multi-turn queries like + `chat `__, + this is a repeated field that contains the conversation + history and the latest request. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateContentResponse: + Response from the model supporting multiple candidate + responses. + + Safety ratings and content filtering are reported for + both prompt in + GenerateContentResponse.prompt_feedback and for each + candidate in finish_reason and in safety_ratings. The + API: - Returns either all requested candidates or + none of them - Returns no candidates at all only if + there was something wrong with the prompt (check + prompt_feedback) - Reports feedback on each candidate + in finish_reason and safety_ratings. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.GenerateContentRequest): + request = generative_service.GenerateContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents: + request.contents.extend(contents) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.generate_content + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def generate_answer( + self, + request: Optional[Union[generative_service.GenerateAnswerRequest, dict]] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + safety_settings: Optional[MutableSequence[safety.SafetySetting]] = None, + answer_style: Optional[ + generative_service.GenerateAnswerRequest.AnswerStyle + ] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.GenerateAnswerResponse: + r"""Generates a grounded answer from the model given an input + ``GenerateAnswerRequest``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_generate_answer(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateAnswerRequest( + model="model_value", + answer_style="VERBOSE", + ) + + # Make the request + response = await client.generate_answer(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest, dict]]): + The request object. Request to generate a grounded answer from the + ``Model``. + model (:class:`str`): + Required. The name of the ``Model`` to use for + generating the grounded response. + + Format: ``model=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (:class:`MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]`): + Required. The content of the current conversation with + the ``Model``. For single-turn queries, this is a single + question to answer. For multi-turn queries, this is a + repeated field that contains conversation history and + the last ``Content`` in the list containing the + question. + + Note: ``GenerateAnswer`` only supports queries in + English. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + safety_settings (:class:`MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]`): + Optional. A list of unique ``SafetySetting`` instances + for blocking unsafe content. + + This will be enforced on the + ``GenerateAnswerRequest.contents`` and + ``GenerateAnswerResponse.candidate``. There should not + be more than one setting for each ``SafetyCategory`` + type. The API will block any contents and responses that + fail to meet the thresholds set by these settings. This + list overrides the default settings for each + ``SafetyCategory`` specified in the safety_settings. If + there is no ``SafetySetting`` for a given + ``SafetyCategory`` provided in the list, the API will + use the default safety setting for that category. Harm + categories HARM_CATEGORY_HATE_SPEECH, + HARM_CATEGORY_SEXUALLY_EXPLICIT, + HARM_CATEGORY_DANGEROUS_CONTENT, + HARM_CATEGORY_HARASSMENT are supported. Refer to the + `guide `__ + for detailed information on available safety settings. + Also refer to the `Safety + guidance `__ + to learn how to incorporate safety considerations in + your AI applications. + + This corresponds to the ``safety_settings`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + answer_style (:class:`google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest.AnswerStyle`): + Required. Style in which answers + should be returned. + + This corresponds to the ``answer_style`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateAnswerResponse: + Response from the model for a + grounded answer. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents, safety_settings, answer_style]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.GenerateAnswerRequest): + request = generative_service.GenerateAnswerRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if answer_style is not None: + request.answer_style = answer_style + if contents: + request.contents.extend(contents) + if safety_settings: + request.safety_settings.extend(safety_settings) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.generate_answer + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def stream_generate_content( + self, + request: Optional[ + Union[generative_service.GenerateContentRequest, dict] + ] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[generative_service.GenerateContentResponse]]: + r"""Generates a `streamed + response `__ + from the model given an input ``GenerateContentRequest``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_stream_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + stream = await client.stream_generate_content(request=request) + + # Handle the response + async for response in stream: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GenerateContentRequest, dict]]): + The request object. Request to generate a completion from + the model. + model (:class:`str`): + Required. The name of the ``Model`` to use for + generating the completion. + + Format: ``models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (:class:`MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]`): + Required. The content of the current conversation with + the model. + + For single-turn queries, this is a single instance. For + multi-turn queries like + `chat `__, + this is a repeated field that contains the conversation + history and the latest request. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + AsyncIterable[google.ai.generativelanguage_v1alpha.types.GenerateContentResponse]: + Response from the model supporting multiple candidate + responses. + + Safety ratings and content filtering are reported for + both prompt in + GenerateContentResponse.prompt_feedback and for each + candidate in finish_reason and in safety_ratings. The + API: - Returns either all requested candidates or + none of them - Returns no candidates at all only if + there was something wrong with the prompt (check + prompt_feedback) - Reports feedback on each candidate + in finish_reason and safety_ratings. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.GenerateContentRequest): + request = generative_service.GenerateContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents: + request.contents.extend(contents) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.stream_generate_content + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def embed_content( + self, + request: Optional[Union[generative_service.EmbedContentRequest, dict]] = None, + *, + model: Optional[str] = None, + content: Optional[gag_content.Content] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.EmbedContentResponse: + r"""Generates a text embedding vector from the input ``Content`` + using the specified `Gemini Embedding + model `__. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_embed_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedContentRequest( + model="model_value", + ) + + # Make the request + response = await client.embed_content(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest, dict]]): + The request object. Request containing the ``Content`` for the model to + embed. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + content (:class:`google.ai.generativelanguage_v1alpha.types.Content`): + Required. The content to embed. Only the ``parts.text`` + fields will be counted. + + This corresponds to the ``content`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.EmbedContentResponse: + The response to an EmbedContentRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, content]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.EmbedContentRequest): + request = generative_service.EmbedContentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if content is not None: + request.content = content + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.embed_content + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_embed_contents( + self, + request: Optional[ + Union[generative_service.BatchEmbedContentsRequest, dict] + ] = None, + *, + model: Optional[str] = None, + requests: Optional[ + MutableSequence[generative_service.EmbedContentRequest] + ] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.BatchEmbedContentsResponse: + r"""Generates multiple embedding vectors from the input ``Content`` + which consists of a batch of strings represented as + ``EmbedContentRequest`` objects. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_batch_embed_contents(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.EmbedContentRequest() + requests.model = "model_value" + + request = generativelanguage_v1alpha.BatchEmbedContentsRequest( + model="model_value", + requests=requests, + ) + + # Make the request + response = await client.batch_embed_contents(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsRequest, dict]]): + The request object. Batch request to get embeddings from + the model for a list of prompts. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + requests (:class:`MutableSequence[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest]`): + Required. Embed requests for the batch. The model in + each of these requests must match the model specified + ``BatchEmbedContentsRequest.model``. + + This corresponds to the ``requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsResponse: + The response to a BatchEmbedContentsRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, requests]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.BatchEmbedContentsRequest): + request = generative_service.BatchEmbedContentsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if requests: + request.requests.extend(requests) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_embed_contents + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def count_tokens( + self, + request: Optional[Union[generative_service.CountTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.CountTokensResponse: + r"""Runs a model's tokenizer on input ``Content`` and returns the + token count. Refer to the `tokens + guide `__ to learn + more about tokens. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_count_tokens(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CountTokensRequest( + model="model_value", + ) + + # Make the request + response = await client.count_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CountTokensRequest, dict]]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (:class:`MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]`): + Optional. The input given to the model as a prompt. This + field is ignored when ``generate_content_request`` is + set. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CountTokensResponse: + A response from CountTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.CountTokensRequest): + request = generative_service.CountTokensRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents: + request.contents.extend(contents) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.count_tokens + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def bidi_generate_content( + self, + requests: Optional[ + AsyncIterator[generative_service.BidiGenerateContentClientMessage] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Awaitable[AsyncIterable[generative_service.BidiGenerateContentServerMessage]]: + r"""Low-Latency bidirectional streaming API that supports + audio and video streaming inputs can produce multimodal + output streams (audio and text). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_bidi_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + setup = generativelanguage_v1alpha.BidiGenerateContentSetup() + setup.model = "model_value" + + request = generativelanguage_v1alpha.BidiGenerateContentClientMessage( + setup=setup, + ) + + # This method expects an iterator which contains + # 'generativelanguage_v1alpha.BidiGenerateContentClientMessage' objects + # Here we create a generator that yields a single `request` for + # demonstrative purposes. + requests = [request] + + def request_generator(): + for request in requests: + yield request + + # Make the request + stream = await client.bidi_generate_content(requests=request_generator()) + + # Handle the response + async for response in stream: + print(response) + + Args: + requests (AsyncIterator[`google.ai.generativelanguage_v1alpha.types.BidiGenerateContentClientMessage`]): + The request object AsyncIterator. Messages sent by the client in the + BidiGenerateContent call. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + AsyncIterable[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentServerMessage]: + Response message for the + BidiGenerateContent call. + + """ + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.bidi_generate_content + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = rpc( + requests, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "GenerativeServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("GenerativeServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/client.py new file mode 100644 index 000000000000..767855b5befb --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/client.py @@ -0,0 +1,1751 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Iterable, + Iterator, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import generative_service, safety +from google.ai.generativelanguage_v1alpha.types import content +from google.ai.generativelanguage_v1alpha.types import content as gag_content + +from .transports.base import DEFAULT_CLIENT_INFO, GenerativeServiceTransport +from .transports.grpc import GenerativeServiceGrpcTransport +from .transports.grpc_asyncio import GenerativeServiceGrpcAsyncIOTransport +from .transports.rest import GenerativeServiceRestTransport + + +class GenerativeServiceClientMeta(type): + """Metaclass for the GenerativeService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[GenerativeServiceTransport]] + _transport_registry["grpc"] = GenerativeServiceGrpcTransport + _transport_registry["grpc_asyncio"] = GenerativeServiceGrpcAsyncIOTransport + _transport_registry["rest"] = GenerativeServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[GenerativeServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class GenerativeServiceClient(metaclass=GenerativeServiceClientMeta): + """API for using Large Models that generate multimodal content + and have additional capabilities beyond text generation. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + GenerativeServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + GenerativeServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> GenerativeServiceTransport: + """Returns the transport used by the client instance. + + Returns: + GenerativeServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def cached_content_path( + id: str, + ) -> str: + """Returns a fully-qualified cached_content string.""" + return "cachedContents/{id}".format( + id=id, + ) + + @staticmethod + def parse_cached_content_path(path: str) -> Dict[str, str]: + """Parses a cached_content path into its component segments.""" + m = re.match(r"^cachedContents/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def model_path( + model: str, + ) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format( + model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = GenerativeServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = GenerativeServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = GenerativeServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = GenerativeServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + GenerativeServiceTransport, + Callable[..., GenerativeServiceTransport], + ] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the generative service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,GenerativeServiceTransport,Callable[..., GenerativeServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the GenerativeServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = GenerativeServiceClient._read_environment_variables() + self._client_cert_source = GenerativeServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = GenerativeServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, GenerativeServiceTransport) + if transport_provided: + # transport is a GenerativeServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(GenerativeServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or GenerativeServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[GenerativeServiceTransport], + Callable[..., GenerativeServiceTransport], + ] = ( + GenerativeServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., GenerativeServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.GenerativeServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "credentialsType": None, + }, + ) + + def generate_content( + self, + request: Optional[ + Union[generative_service.GenerateContentRequest, dict] + ] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.GenerateContentResponse: + r"""Generates a model response given an input + ``GenerateContentRequest``. Refer to the `text generation + guide `__ + for detailed usage information. Input capabilities differ + between models, including tuned models. Refer to the `model + guide `__ + and `tuning + guide `__ + for details. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + response = client.generate_content(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GenerateContentRequest, dict]): + The request object. Request to generate a completion from + the model. + model (str): + Required. The name of the ``Model`` to use for + generating the completion. + + Format: ``models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Required. The content of the current conversation with + the model. + + For single-turn queries, this is a single instance. For + multi-turn queries like + `chat `__, + this is a repeated field that contains the conversation + history and the latest request. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateContentResponse: + Response from the model supporting multiple candidate + responses. + + Safety ratings and content filtering are reported for + both prompt in + GenerateContentResponse.prompt_feedback and for each + candidate in finish_reason and in safety_ratings. The + API: - Returns either all requested candidates or + none of them - Returns no candidates at all only if + there was something wrong with the prompt (check + prompt_feedback) - Reports feedback on each candidate + in finish_reason and safety_ratings. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.GenerateContentRequest): + request = generative_service.GenerateContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents is not None: + request.contents = contents + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_content] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def generate_answer( + self, + request: Optional[Union[generative_service.GenerateAnswerRequest, dict]] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + safety_settings: Optional[MutableSequence[safety.SafetySetting]] = None, + answer_style: Optional[ + generative_service.GenerateAnswerRequest.AnswerStyle + ] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.GenerateAnswerResponse: + r"""Generates a grounded answer from the model given an input + ``GenerateAnswerRequest``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_generate_answer(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateAnswerRequest( + model="model_value", + answer_style="VERBOSE", + ) + + # Make the request + response = client.generate_answer(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest, dict]): + The request object. Request to generate a grounded answer from the + ``Model``. + model (str): + Required. The name of the ``Model`` to use for + generating the grounded response. + + Format: ``model=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Required. The content of the current conversation with + the ``Model``. For single-turn queries, this is a single + question to answer. For multi-turn queries, this is a + repeated field that contains conversation history and + the last ``Content`` in the list containing the + question. + + Note: ``GenerateAnswer`` only supports queries in + English. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + safety_settings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]): + Optional. A list of unique ``SafetySetting`` instances + for blocking unsafe content. + + This will be enforced on the + ``GenerateAnswerRequest.contents`` and + ``GenerateAnswerResponse.candidate``. There should not + be more than one setting for each ``SafetyCategory`` + type. The API will block any contents and responses that + fail to meet the thresholds set by these settings. This + list overrides the default settings for each + ``SafetyCategory`` specified in the safety_settings. If + there is no ``SafetySetting`` for a given + ``SafetyCategory`` provided in the list, the API will + use the default safety setting for that category. Harm + categories HARM_CATEGORY_HATE_SPEECH, + HARM_CATEGORY_SEXUALLY_EXPLICIT, + HARM_CATEGORY_DANGEROUS_CONTENT, + HARM_CATEGORY_HARASSMENT are supported. Refer to the + `guide `__ + for detailed information on available safety settings. + Also refer to the `Safety + guidance `__ + to learn how to incorporate safety considerations in + your AI applications. + + This corresponds to the ``safety_settings`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + answer_style (google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest.AnswerStyle): + Required. Style in which answers + should be returned. + + This corresponds to the ``answer_style`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateAnswerResponse: + Response from the model for a + grounded answer. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents, safety_settings, answer_style]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.GenerateAnswerRequest): + request = generative_service.GenerateAnswerRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents is not None: + request.contents = contents + if safety_settings is not None: + request.safety_settings = safety_settings + if answer_style is not None: + request.answer_style = answer_style + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_answer] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def stream_generate_content( + self, + request: Optional[ + Union[generative_service.GenerateContentRequest, dict] + ] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[generative_service.GenerateContentResponse]: + r"""Generates a `streamed + response `__ + from the model given an input ``GenerateContentRequest``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_stream_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + stream = client.stream_generate_content(request=request) + + # Handle the response + for response in stream: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GenerateContentRequest, dict]): + The request object. Request to generate a completion from + the model. + model (str): + Required. The name of the ``Model`` to use for + generating the completion. + + Format: ``models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Required. The content of the current conversation with + the model. + + For single-turn queries, this is a single instance. For + multi-turn queries like + `chat `__, + this is a repeated field that contains the conversation + history and the latest request. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + Iterable[google.ai.generativelanguage_v1alpha.types.GenerateContentResponse]: + Response from the model supporting multiple candidate + responses. + + Safety ratings and content filtering are reported for + both prompt in + GenerateContentResponse.prompt_feedback and for each + candidate in finish_reason and in safety_ratings. The + API: - Returns either all requested candidates or + none of them - Returns no candidates at all only if + there was something wrong with the prompt (check + prompt_feedback) - Reports feedback on each candidate + in finish_reason and safety_ratings. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.GenerateContentRequest): + request = generative_service.GenerateContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents is not None: + request.contents = contents + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.stream_generate_content] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def embed_content( + self, + request: Optional[Union[generative_service.EmbedContentRequest, dict]] = None, + *, + model: Optional[str] = None, + content: Optional[gag_content.Content] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.EmbedContentResponse: + r"""Generates a text embedding vector from the input ``Content`` + using the specified `Gemini Embedding + model `__. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_embed_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedContentRequest( + model="model_value", + ) + + # Make the request + response = client.embed_content(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest, dict]): + The request object. Request containing the ``Content`` for the model to + embed. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + content (google.ai.generativelanguage_v1alpha.types.Content): + Required. The content to embed. Only the ``parts.text`` + fields will be counted. + + This corresponds to the ``content`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.EmbedContentResponse: + The response to an EmbedContentRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, content]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.EmbedContentRequest): + request = generative_service.EmbedContentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if content is not None: + request.content = content + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.embed_content] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_embed_contents( + self, + request: Optional[ + Union[generative_service.BatchEmbedContentsRequest, dict] + ] = None, + *, + model: Optional[str] = None, + requests: Optional[ + MutableSequence[generative_service.EmbedContentRequest] + ] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.BatchEmbedContentsResponse: + r"""Generates multiple embedding vectors from the input ``Content`` + which consists of a batch of strings represented as + ``EmbedContentRequest`` objects. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_batch_embed_contents(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.EmbedContentRequest() + requests.model = "model_value" + + request = generativelanguage_v1alpha.BatchEmbedContentsRequest( + model="model_value", + requests=requests, + ) + + # Make the request + response = client.batch_embed_contents(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsRequest, dict]): + The request object. Batch request to get embeddings from + the model for a list of prompts. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + requests (MutableSequence[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest]): + Required. Embed requests for the batch. The model in + each of these requests must match the model specified + ``BatchEmbedContentsRequest.model``. + + This corresponds to the ``requests`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsResponse: + The response to a BatchEmbedContentsRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, requests]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.BatchEmbedContentsRequest): + request = generative_service.BatchEmbedContentsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if requests is not None: + request.requests = requests + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_embed_contents] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def count_tokens( + self, + request: Optional[Union[generative_service.CountTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + contents: Optional[MutableSequence[content.Content]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.CountTokensResponse: + r"""Runs a model's tokenizer on input ``Content`` and returns the + token count. Refer to the `tokens + guide `__ to learn + more about tokens. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_count_tokens(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CountTokensRequest( + model="model_value", + ) + + # Make the request + response = client.count_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CountTokensRequest, dict]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Optional. The input given to the model as a prompt. This + field is ignored when ``generate_content_request`` is + set. + + This corresponds to the ``contents`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CountTokensResponse: + A response from CountTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, contents]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, generative_service.CountTokensRequest): + request = generative_service.CountTokensRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if contents is not None: + request.contents = contents + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.count_tokens] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def bidi_generate_content( + self, + requests: Optional[ + Iterator[generative_service.BidiGenerateContentClientMessage] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> Iterable[generative_service.BidiGenerateContentServerMessage]: + r"""Low-Latency bidirectional streaming API that supports + audio and video streaming inputs can produce multimodal + output streams (audio and text). + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_bidi_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + setup = generativelanguage_v1alpha.BidiGenerateContentSetup() + setup.model = "model_value" + + request = generativelanguage_v1alpha.BidiGenerateContentClientMessage( + setup=setup, + ) + + # This method expects an iterator which contains + # 'generativelanguage_v1alpha.BidiGenerateContentClientMessage' objects + # Here we create a generator that yields a single `request` for + # demonstrative purposes. + requests = [request] + + def request_generator(): + for request in requests: + yield request + + # Make the request + stream = client.bidi_generate_content(requests=request_generator()) + + # Handle the response + for response in stream: + print(response) + + Args: + requests (Iterator[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentClientMessage]): + The request object iterator. Messages sent by the client in the + BidiGenerateContent call. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + Iterable[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentServerMessage]: + Response message for the + BidiGenerateContent call. + + """ + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.bidi_generate_content] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + requests, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "GenerativeServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("GenerativeServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/README.rst new file mode 100644 index 000000000000..b8180cab1e5c --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`GenerativeServiceTransport` is the ABC for all transports. +- public child `GenerativeServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `GenerativeServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseGenerativeServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `GenerativeServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/__init__.py new file mode 100644 index 000000000000..1321d30dc967 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import GenerativeServiceTransport +from .grpc import GenerativeServiceGrpcTransport +from .grpc_asyncio import GenerativeServiceGrpcAsyncIOTransport +from .rest import GenerativeServiceRestInterceptor, GenerativeServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[GenerativeServiceTransport]] +_transport_registry["grpc"] = GenerativeServiceGrpcTransport +_transport_registry["grpc_asyncio"] = GenerativeServiceGrpcAsyncIOTransport +_transport_registry["rest"] = GenerativeServiceRestTransport + +__all__ = ( + "GenerativeServiceTransport", + "GenerativeServiceGrpcTransport", + "GenerativeServiceGrpcAsyncIOTransport", + "GenerativeServiceRestTransport", + "GenerativeServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/base.py new file mode 100644 index 000000000000..f35e8e7c53df --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/base.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import generative_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class GenerativeServiceTransport(abc.ABC): + """Abstract transport class for GenerativeService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.generate_content: gapic_v1.method.wrap_method( + self.generate_content, + default_timeout=None, + client_info=client_info, + ), + self.generate_answer: gapic_v1.method.wrap_method( + self.generate_answer, + default_timeout=None, + client_info=client_info, + ), + self.stream_generate_content: gapic_v1.method.wrap_method( + self.stream_generate_content, + default_timeout=None, + client_info=client_info, + ), + self.embed_content: gapic_v1.method.wrap_method( + self.embed_content, + default_timeout=None, + client_info=client_info, + ), + self.batch_embed_contents: gapic_v1.method.wrap_method( + self.batch_embed_contents, + default_timeout=None, + client_info=client_info, + ), + self.count_tokens: gapic_v1.method.wrap_method( + self.count_tokens, + default_timeout=None, + client_info=client_info, + ), + self.bidi_generate_content: gapic_v1.method.wrap_method( + self.bidi_generate_content, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + Union[ + generative_service.GenerateContentResponse, + Awaitable[generative_service.GenerateContentResponse], + ], + ]: + raise NotImplementedError() + + @property + def generate_answer( + self, + ) -> Callable[ + [generative_service.GenerateAnswerRequest], + Union[ + generative_service.GenerateAnswerResponse, + Awaitable[generative_service.GenerateAnswerResponse], + ], + ]: + raise NotImplementedError() + + @property + def stream_generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + Union[ + generative_service.GenerateContentResponse, + Awaitable[generative_service.GenerateContentResponse], + ], + ]: + raise NotImplementedError() + + @property + def embed_content( + self, + ) -> Callable[ + [generative_service.EmbedContentRequest], + Union[ + generative_service.EmbedContentResponse, + Awaitable[generative_service.EmbedContentResponse], + ], + ]: + raise NotImplementedError() + + @property + def batch_embed_contents( + self, + ) -> Callable[ + [generative_service.BatchEmbedContentsRequest], + Union[ + generative_service.BatchEmbedContentsResponse, + Awaitable[generative_service.BatchEmbedContentsResponse], + ], + ]: + raise NotImplementedError() + + @property + def count_tokens( + self, + ) -> Callable[ + [generative_service.CountTokensRequest], + Union[ + generative_service.CountTokensResponse, + Awaitable[generative_service.CountTokensResponse], + ], + ]: + raise NotImplementedError() + + @property + def bidi_generate_content( + self, + ) -> Callable[ + [generative_service.BidiGenerateContentClientMessage], + Union[ + generative_service.BidiGenerateContentServerMessage, + Awaitable[generative_service.BidiGenerateContentServerMessage], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("GenerativeServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc.py new file mode 100644 index 000000000000..e17edf1c6a1d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc.py @@ -0,0 +1,591 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import generative_service + +from .base import DEFAULT_CLIENT_INFO, GenerativeServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class GenerativeServiceGrpcTransport(GenerativeServiceTransport): + """gRPC backend transport for GenerativeService. + + API for using Large Models that generate multimodal content + and have additional capabilities beyond text generation. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + generative_service.GenerateContentResponse, + ]: + r"""Return a callable for the generate content method over gRPC. + + Generates a model response given an input + ``GenerateContentRequest``. Refer to the `text generation + guide `__ + for detailed usage information. Input capabilities differ + between models, including tuned models. Refer to the `model + guide `__ + and `tuning + guide `__ + for details. + + Returns: + Callable[[~.GenerateContentRequest], + ~.GenerateContentResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_content" not in self._stubs: + self._stubs["generate_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/GenerateContent", + request_serializer=generative_service.GenerateContentRequest.serialize, + response_deserializer=generative_service.GenerateContentResponse.deserialize, + ) + return self._stubs["generate_content"] + + @property + def generate_answer( + self, + ) -> Callable[ + [generative_service.GenerateAnswerRequest], + generative_service.GenerateAnswerResponse, + ]: + r"""Return a callable for the generate answer method over gRPC. + + Generates a grounded answer from the model given an input + ``GenerateAnswerRequest``. + + Returns: + Callable[[~.GenerateAnswerRequest], + ~.GenerateAnswerResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_answer" not in self._stubs: + self._stubs["generate_answer"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/GenerateAnswer", + request_serializer=generative_service.GenerateAnswerRequest.serialize, + response_deserializer=generative_service.GenerateAnswerResponse.deserialize, + ) + return self._stubs["generate_answer"] + + @property + def stream_generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + generative_service.GenerateContentResponse, + ]: + r"""Return a callable for the stream generate content method over gRPC. + + Generates a `streamed + response `__ + from the model given an input ``GenerateContentRequest``. + + Returns: + Callable[[~.GenerateContentRequest], + ~.GenerateContentResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "stream_generate_content" not in self._stubs: + self._stubs["stream_generate_content"] = self._logged_channel.unary_stream( + "/google.ai.generativelanguage.v1alpha.GenerativeService/StreamGenerateContent", + request_serializer=generative_service.GenerateContentRequest.serialize, + response_deserializer=generative_service.GenerateContentResponse.deserialize, + ) + return self._stubs["stream_generate_content"] + + @property + def embed_content( + self, + ) -> Callable[ + [generative_service.EmbedContentRequest], + generative_service.EmbedContentResponse, + ]: + r"""Return a callable for the embed content method over gRPC. + + Generates a text embedding vector from the input ``Content`` + using the specified `Gemini Embedding + model `__. + + Returns: + Callable[[~.EmbedContentRequest], + ~.EmbedContentResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "embed_content" not in self._stubs: + self._stubs["embed_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/EmbedContent", + request_serializer=generative_service.EmbedContentRequest.serialize, + response_deserializer=generative_service.EmbedContentResponse.deserialize, + ) + return self._stubs["embed_content"] + + @property + def batch_embed_contents( + self, + ) -> Callable[ + [generative_service.BatchEmbedContentsRequest], + generative_service.BatchEmbedContentsResponse, + ]: + r"""Return a callable for the batch embed contents method over gRPC. + + Generates multiple embedding vectors from the input ``Content`` + which consists of a batch of strings represented as + ``EmbedContentRequest`` objects. + + Returns: + Callable[[~.BatchEmbedContentsRequest], + ~.BatchEmbedContentsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_embed_contents" not in self._stubs: + self._stubs["batch_embed_contents"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/BatchEmbedContents", + request_serializer=generative_service.BatchEmbedContentsRequest.serialize, + response_deserializer=generative_service.BatchEmbedContentsResponse.deserialize, + ) + return self._stubs["batch_embed_contents"] + + @property + def count_tokens( + self, + ) -> Callable[ + [generative_service.CountTokensRequest], generative_service.CountTokensResponse + ]: + r"""Return a callable for the count tokens method over gRPC. + + Runs a model's tokenizer on input ``Content`` and returns the + token count. Refer to the `tokens + guide `__ to learn + more about tokens. + + Returns: + Callable[[~.CountTokensRequest], + ~.CountTokensResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "count_tokens" not in self._stubs: + self._stubs["count_tokens"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/CountTokens", + request_serializer=generative_service.CountTokensRequest.serialize, + response_deserializer=generative_service.CountTokensResponse.deserialize, + ) + return self._stubs["count_tokens"] + + @property + def bidi_generate_content( + self, + ) -> Callable[ + [generative_service.BidiGenerateContentClientMessage], + generative_service.BidiGenerateContentServerMessage, + ]: + r"""Return a callable for the bidi generate content method over gRPC. + + Low-Latency bidirectional streaming API that supports + audio and video streaming inputs can produce multimodal + output streams (audio and text). + + Returns: + Callable[[~.BidiGenerateContentClientMessage], + ~.BidiGenerateContentServerMessage]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "bidi_generate_content" not in self._stubs: + self._stubs["bidi_generate_content"] = self._logged_channel.stream_stream( + "/google.ai.generativelanguage.v1alpha.GenerativeService/BidiGenerateContent", + request_serializer=generative_service.BidiGenerateContentClientMessage.serialize, + response_deserializer=generative_service.BidiGenerateContentServerMessage.deserialize, + ) + return self._stubs["bidi_generate_content"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("GenerativeServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..584719dd4572 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/grpc_asyncio.py @@ -0,0 +1,654 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import generative_service + +from .base import DEFAULT_CLIENT_INFO, GenerativeServiceTransport +from .grpc import GenerativeServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class GenerativeServiceGrpcAsyncIOTransport(GenerativeServiceTransport): + """gRPC AsyncIO backend transport for GenerativeService. + + API for using Large Models that generate multimodal content + and have additional capabilities beyond text generation. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + Awaitable[generative_service.GenerateContentResponse], + ]: + r"""Return a callable for the generate content method over gRPC. + + Generates a model response given an input + ``GenerateContentRequest``. Refer to the `text generation + guide `__ + for detailed usage information. Input capabilities differ + between models, including tuned models. Refer to the `model + guide `__ + and `tuning + guide `__ + for details. + + Returns: + Callable[[~.GenerateContentRequest], + Awaitable[~.GenerateContentResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_content" not in self._stubs: + self._stubs["generate_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/GenerateContent", + request_serializer=generative_service.GenerateContentRequest.serialize, + response_deserializer=generative_service.GenerateContentResponse.deserialize, + ) + return self._stubs["generate_content"] + + @property + def generate_answer( + self, + ) -> Callable[ + [generative_service.GenerateAnswerRequest], + Awaitable[generative_service.GenerateAnswerResponse], + ]: + r"""Return a callable for the generate answer method over gRPC. + + Generates a grounded answer from the model given an input + ``GenerateAnswerRequest``. + + Returns: + Callable[[~.GenerateAnswerRequest], + Awaitable[~.GenerateAnswerResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_answer" not in self._stubs: + self._stubs["generate_answer"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/GenerateAnswer", + request_serializer=generative_service.GenerateAnswerRequest.serialize, + response_deserializer=generative_service.GenerateAnswerResponse.deserialize, + ) + return self._stubs["generate_answer"] + + @property + def stream_generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + Awaitable[generative_service.GenerateContentResponse], + ]: + r"""Return a callable for the stream generate content method over gRPC. + + Generates a `streamed + response `__ + from the model given an input ``GenerateContentRequest``. + + Returns: + Callable[[~.GenerateContentRequest], + Awaitable[~.GenerateContentResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "stream_generate_content" not in self._stubs: + self._stubs["stream_generate_content"] = self._logged_channel.unary_stream( + "/google.ai.generativelanguage.v1alpha.GenerativeService/StreamGenerateContent", + request_serializer=generative_service.GenerateContentRequest.serialize, + response_deserializer=generative_service.GenerateContentResponse.deserialize, + ) + return self._stubs["stream_generate_content"] + + @property + def embed_content( + self, + ) -> Callable[ + [generative_service.EmbedContentRequest], + Awaitable[generative_service.EmbedContentResponse], + ]: + r"""Return a callable for the embed content method over gRPC. + + Generates a text embedding vector from the input ``Content`` + using the specified `Gemini Embedding + model `__. + + Returns: + Callable[[~.EmbedContentRequest], + Awaitable[~.EmbedContentResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "embed_content" not in self._stubs: + self._stubs["embed_content"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/EmbedContent", + request_serializer=generative_service.EmbedContentRequest.serialize, + response_deserializer=generative_service.EmbedContentResponse.deserialize, + ) + return self._stubs["embed_content"] + + @property + def batch_embed_contents( + self, + ) -> Callable[ + [generative_service.BatchEmbedContentsRequest], + Awaitable[generative_service.BatchEmbedContentsResponse], + ]: + r"""Return a callable for the batch embed contents method over gRPC. + + Generates multiple embedding vectors from the input ``Content`` + which consists of a batch of strings represented as + ``EmbedContentRequest`` objects. + + Returns: + Callable[[~.BatchEmbedContentsRequest], + Awaitable[~.BatchEmbedContentsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_embed_contents" not in self._stubs: + self._stubs["batch_embed_contents"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/BatchEmbedContents", + request_serializer=generative_service.BatchEmbedContentsRequest.serialize, + response_deserializer=generative_service.BatchEmbedContentsResponse.deserialize, + ) + return self._stubs["batch_embed_contents"] + + @property + def count_tokens( + self, + ) -> Callable[ + [generative_service.CountTokensRequest], + Awaitable[generative_service.CountTokensResponse], + ]: + r"""Return a callable for the count tokens method over gRPC. + + Runs a model's tokenizer on input ``Content`` and returns the + token count. Refer to the `tokens + guide `__ to learn + more about tokens. + + Returns: + Callable[[~.CountTokensRequest], + Awaitable[~.CountTokensResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "count_tokens" not in self._stubs: + self._stubs["count_tokens"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.GenerativeService/CountTokens", + request_serializer=generative_service.CountTokensRequest.serialize, + response_deserializer=generative_service.CountTokensResponse.deserialize, + ) + return self._stubs["count_tokens"] + + @property + def bidi_generate_content( + self, + ) -> Callable[ + [generative_service.BidiGenerateContentClientMessage], + Awaitable[generative_service.BidiGenerateContentServerMessage], + ]: + r"""Return a callable for the bidi generate content method over gRPC. + + Low-Latency bidirectional streaming API that supports + audio and video streaming inputs can produce multimodal + output streams (audio and text). + + Returns: + Callable[[~.BidiGenerateContentClientMessage], + Awaitable[~.BidiGenerateContentServerMessage]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "bidi_generate_content" not in self._stubs: + self._stubs["bidi_generate_content"] = self._logged_channel.stream_stream( + "/google.ai.generativelanguage.v1alpha.GenerativeService/BidiGenerateContent", + request_serializer=generative_service.BidiGenerateContentClientMessage.serialize, + response_deserializer=generative_service.BidiGenerateContentServerMessage.deserialize, + ) + return self._stubs["bidi_generate_content"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.generate_content: self._wrap_method( + self.generate_content, + default_timeout=None, + client_info=client_info, + ), + self.generate_answer: self._wrap_method( + self.generate_answer, + default_timeout=None, + client_info=client_info, + ), + self.stream_generate_content: self._wrap_method( + self.stream_generate_content, + default_timeout=None, + client_info=client_info, + ), + self.embed_content: self._wrap_method( + self.embed_content, + default_timeout=None, + client_info=client_info, + ), + self.batch_embed_contents: self._wrap_method( + self.batch_embed_contents, + default_timeout=None, + client_info=client_info, + ), + self.count_tokens: self._wrap_method( + self.count_tokens, + default_timeout=None, + client_info=client_info, + ), + self.bidi_generate_content: self._wrap_method( + self.bidi_generate_content, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("GenerativeServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest.py new file mode 100644 index 000000000000..0c4361c7704c --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest.py @@ -0,0 +1,1726 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import generative_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseGenerativeServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class GenerativeServiceRestInterceptor: + """Interceptor for GenerativeService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the GenerativeServiceRestTransport. + + .. code-block:: python + class MyCustomGenerativeServiceInterceptor(GenerativeServiceRestInterceptor): + def pre_batch_embed_contents(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_batch_embed_contents(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_count_tokens(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_count_tokens(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_embed_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_embed_content(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_answer(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_answer(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_content(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_stream_generate_content(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_stream_generate_content(self, response): + logging.log(f"Received response: {response}") + return response + + transport = GenerativeServiceRestTransport(interceptor=MyCustomGenerativeServiceInterceptor()) + client = GenerativeServiceClient(transport=transport) + + + """ + + def pre_batch_embed_contents( + self, + request: generative_service.BatchEmbedContentsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + generative_service.BatchEmbedContentsRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for batch_embed_contents + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_batch_embed_contents( + self, response: generative_service.BatchEmbedContentsResponse + ) -> generative_service.BatchEmbedContentsResponse: + """Post-rpc interceptor for batch_embed_contents + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_count_tokens( + self, + request: generative_service.CountTokensRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + generative_service.CountTokensRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for count_tokens + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_count_tokens( + self, response: generative_service.CountTokensResponse + ) -> generative_service.CountTokensResponse: + """Post-rpc interceptor for count_tokens + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_embed_content( + self, + request: generative_service.EmbedContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + generative_service.EmbedContentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for embed_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_embed_content( + self, response: generative_service.EmbedContentResponse + ) -> generative_service.EmbedContentResponse: + """Post-rpc interceptor for embed_content + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_generate_answer( + self, + request: generative_service.GenerateAnswerRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + generative_service.GenerateAnswerRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for generate_answer + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_generate_answer( + self, response: generative_service.GenerateAnswerResponse + ) -> generative_service.GenerateAnswerResponse: + """Post-rpc interceptor for generate_answer + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_generate_content( + self, + request: generative_service.GenerateContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + generative_service.GenerateContentRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for generate_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_generate_content( + self, response: generative_service.GenerateContentResponse + ) -> generative_service.GenerateContentResponse: + """Post-rpc interceptor for generate_content + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_stream_generate_content( + self, + request: generative_service.GenerateContentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + generative_service.GenerateContentRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for stream_generate_content + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_stream_generate_content( + self, response: rest_streaming.ResponseIterator + ) -> rest_streaming.ResponseIterator: + """Post-rpc interceptor for stream_generate_content + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the GenerativeService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the GenerativeService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class GenerativeServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: GenerativeServiceRestInterceptor + + +class GenerativeServiceRestTransport(_BaseGenerativeServiceRestTransport): + """REST backend synchronous transport for GenerativeService. + + API for using Large Models that generate multimodal content + and have additional capabilities beyond text generation. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[GenerativeServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or GenerativeServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _BatchEmbedContents( + _BaseGenerativeServiceRestTransport._BaseBatchEmbedContents, + GenerativeServiceRestStub, + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.BatchEmbedContents") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: generative_service.BatchEmbedContentsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.BatchEmbedContentsResponse: + r"""Call the batch embed contents method over HTTP. + + Args: + request (~.generative_service.BatchEmbedContentsRequest): + The request object. Batch request to get embeddings from + the model for a list of prompts. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.generative_service.BatchEmbedContentsResponse: + The response to a ``BatchEmbedContentsRequest``. + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseBatchEmbedContents._get_http_options() + ) + + request, metadata = self._interceptor.pre_batch_embed_contents( + request, metadata + ) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseBatchEmbedContents._get_transcoded_request( + http_options, request + ) + + body = _BaseGenerativeServiceRestTransport._BaseBatchEmbedContents._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseBatchEmbedContents._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.BatchEmbedContents", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "BatchEmbedContents", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._BatchEmbedContents._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = generative_service.BatchEmbedContentsResponse() + pb_resp = generative_service.BatchEmbedContentsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_batch_embed_contents(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + generative_service.BatchEmbedContentsResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.batch_embed_contents", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "BatchEmbedContents", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _BidiGenerateContent( + _BaseGenerativeServiceRestTransport._BaseBidiGenerateContent, + GenerativeServiceRestStub, + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.BidiGenerateContent") + + def __call__( + self, + request: generative_service.BidiGenerateContentClientMessage, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + raise NotImplementedError( + "Method BidiGenerateContent is not available over REST transport" + ) + + class _CountTokens( + _BaseGenerativeServiceRestTransport._BaseCountTokens, GenerativeServiceRestStub + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.CountTokens") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: generative_service.CountTokensRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.CountTokensResponse: + r"""Call the count tokens method over HTTP. + + Args: + request (~.generative_service.CountTokensRequest): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.generative_service.CountTokensResponse: + A response from ``CountTokens``. + + It returns the model's ``token_count`` for the + ``prompt``. + + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseCountTokens._get_http_options() + ) + + request, metadata = self._interceptor.pre_count_tokens(request, metadata) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseCountTokens._get_transcoded_request( + http_options, request + ) + + body = _BaseGenerativeServiceRestTransport._BaseCountTokens._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseCountTokens._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.CountTokens", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "CountTokens", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._CountTokens._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = generative_service.CountTokensResponse() + pb_resp = generative_service.CountTokensResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_count_tokens(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = generative_service.CountTokensResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.count_tokens", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "CountTokens", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _EmbedContent( + _BaseGenerativeServiceRestTransport._BaseEmbedContent, GenerativeServiceRestStub + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.EmbedContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: generative_service.EmbedContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.EmbedContentResponse: + r"""Call the embed content method over HTTP. + + Args: + request (~.generative_service.EmbedContentRequest): + The request object. Request containing the ``Content`` for the model to + embed. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.generative_service.EmbedContentResponse: + The response to an ``EmbedContentRequest``. + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseEmbedContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_embed_content(request, metadata) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseEmbedContent._get_transcoded_request( + http_options, request + ) + + body = _BaseGenerativeServiceRestTransport._BaseEmbedContent._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseEmbedContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.EmbedContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "EmbedContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._EmbedContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = generative_service.EmbedContentResponse() + pb_resp = generative_service.EmbedContentResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_embed_content(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = generative_service.EmbedContentResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.embed_content", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "EmbedContent", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GenerateAnswer( + _BaseGenerativeServiceRestTransport._BaseGenerateAnswer, + GenerativeServiceRestStub, + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.GenerateAnswer") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: generative_service.GenerateAnswerRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.GenerateAnswerResponse: + r"""Call the generate answer method over HTTP. + + Args: + request (~.generative_service.GenerateAnswerRequest): + The request object. Request to generate a grounded answer from the + ``Model``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.generative_service.GenerateAnswerResponse: + Response from the model for a + grounded answer. + + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseGenerateAnswer._get_http_options() + ) + + request, metadata = self._interceptor.pre_generate_answer(request, metadata) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseGenerateAnswer._get_transcoded_request( + http_options, request + ) + + body = _BaseGenerativeServiceRestTransport._BaseGenerateAnswer._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseGenerateAnswer._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.GenerateAnswer", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "GenerateAnswer", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._GenerateAnswer._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = generative_service.GenerateAnswerResponse() + pb_resp = generative_service.GenerateAnswerResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_generate_answer(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + generative_service.GenerateAnswerResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.generate_answer", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "GenerateAnswer", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GenerateContent( + _BaseGenerativeServiceRestTransport._BaseGenerateContent, + GenerativeServiceRestStub, + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.GenerateContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: generative_service.GenerateContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> generative_service.GenerateContentResponse: + r"""Call the generate content method over HTTP. + + Args: + request (~.generative_service.GenerateContentRequest): + The request object. Request to generate a completion from + the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.generative_service.GenerateContentResponse: + Response from the model supporting multiple candidate + responses. + + Safety ratings and content filtering are reported for + both prompt in + ``GenerateContentResponse.prompt_feedback`` and for each + candidate in ``finish_reason`` and in + ``safety_ratings``. The API: + + - Returns either all requested candidates or none of + them + - Returns no candidates at all only if there was + something wrong with the prompt (check + ``prompt_feedback``) + - Reports feedback on each candidate in + ``finish_reason`` and ``safety_ratings``. + + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseGenerateContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_generate_content( + request, metadata + ) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseGenerateContent._get_transcoded_request( + http_options, request + ) + + body = _BaseGenerativeServiceRestTransport._BaseGenerateContent._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseGenerateContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.GenerateContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "GenerateContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._GenerateContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = generative_service.GenerateContentResponse() + pb_resp = generative_service.GenerateContentResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_generate_content(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + generative_service.GenerateContentResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.generate_content", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "GenerateContent", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _StreamGenerateContent( + _BaseGenerativeServiceRestTransport._BaseStreamGenerateContent, + GenerativeServiceRestStub, + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.StreamGenerateContent") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + stream=True, + ) + return response + + def __call__( + self, + request: generative_service.GenerateContentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> rest_streaming.ResponseIterator: + r"""Call the stream generate content method over HTTP. + + Args: + request (~.generative_service.GenerateContentRequest): + The request object. Request to generate a completion from + the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.generative_service.GenerateContentResponse: + Response from the model supporting multiple candidate + responses. + + Safety ratings and content filtering are reported for + both prompt in + ``GenerateContentResponse.prompt_feedback`` and for each + candidate in ``finish_reason`` and in + ``safety_ratings``. The API: + + - Returns either all requested candidates or none of + them + - Returns no candidates at all only if there was + something wrong with the prompt (check + ``prompt_feedback``) + - Reports feedback on each candidate in + ``finish_reason`` and ``safety_ratings``. + + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseStreamGenerateContent._get_http_options() + ) + + request, metadata = self._interceptor.pre_stream_generate_content( + request, metadata + ) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseStreamGenerateContent._get_transcoded_request( + http_options, request + ) + + body = _BaseGenerativeServiceRestTransport._BaseStreamGenerateContent._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseStreamGenerateContent._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.StreamGenerateContent", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "StreamGenerateContent", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ( + GenerativeServiceRestTransport._StreamGenerateContent._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = rest_streaming.ResponseIterator( + response, generative_service.GenerateContentResponse + ) + + resp = self._interceptor.post_stream_generate_content(resp) + return resp + + @property + def batch_embed_contents( + self, + ) -> Callable[ + [generative_service.BatchEmbedContentsRequest], + generative_service.BatchEmbedContentsResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchEmbedContents(self._session, self._host, self._interceptor) # type: ignore + + @property + def bidi_generate_content( + self, + ) -> Callable[ + [generative_service.BidiGenerateContentClientMessage], + generative_service.BidiGenerateContentServerMessage, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BidiGenerateContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def count_tokens( + self, + ) -> Callable[ + [generative_service.CountTokensRequest], generative_service.CountTokensResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CountTokens(self._session, self._host, self._interceptor) # type: ignore + + @property + def embed_content( + self, + ) -> Callable[ + [generative_service.EmbedContentRequest], + generative_service.EmbedContentResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._EmbedContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_answer( + self, + ) -> Callable[ + [generative_service.GenerateAnswerRequest], + generative_service.GenerateAnswerResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateAnswer(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + generative_service.GenerateContentResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def stream_generate_content( + self, + ) -> Callable[ + [generative_service.GenerateContentRequest], + generative_service.GenerateContentResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._StreamGenerateContent(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseGenerativeServiceRestTransport._BaseGetOperation, GenerativeServiceRestStub + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseGenerativeServiceRestTransport._BaseListOperations, + GenerativeServiceRestStub, + ): + def __hash__(self): + return hash("GenerativeServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseGenerativeServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseGenerativeServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseGenerativeServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.GenerativeServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = GenerativeServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("GenerativeServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest_base.py new file mode 100644 index 000000000000..f622ef84db0d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/generative_service/transports/rest_base.py @@ -0,0 +1,510 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import generative_service + +from .base import DEFAULT_CLIENT_INFO, GenerativeServiceTransport + + +class _BaseGenerativeServiceRestTransport(GenerativeServiceTransport): + """Base REST backend transport for GenerativeService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseBatchEmbedContents: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:batchEmbedContents", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = generative_service.BatchEmbedContentsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseGenerativeServiceRestTransport._BaseBatchEmbedContents._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseBidiGenerateContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + class _BaseCountTokens: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:countTokens", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = generative_service.CountTokensRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseGenerativeServiceRestTransport._BaseCountTokens._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseEmbedContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:embedContent", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = generative_service.EmbedContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseGenerativeServiceRestTransport._BaseEmbedContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGenerateAnswer: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:generateAnswer", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = generative_service.GenerateAnswerRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseGenerativeServiceRestTransport._BaseGenerateAnswer._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGenerateContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:generateContent", + "body": "*", + }, + { + "method": "post", + "uri": "/v1alpha/{model=tunedModels/*}:generateContent", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = generative_service.GenerateContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseGenerativeServiceRestTransport._BaseGenerateContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseStreamGenerateContent: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:streamGenerateContent", + "body": "*", + }, + { + "method": "post", + "uri": "/v1alpha/{model=tunedModels/*}:streamGenerateContent", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = generative_service.GenerateContentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseGenerativeServiceRestTransport._BaseStreamGenerateContent._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseGenerativeServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/__init__.py new file mode 100644 index 000000000000..86afc5d9278d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import ModelServiceAsyncClient +from .client import ModelServiceClient + +__all__ = ( + "ModelServiceClient", + "ModelServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/async_client.py new file mode 100644 index 000000000000..1844d64aea7f --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/async_client.py @@ -0,0 +1,1261 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.model_service import pagers +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +from .client import ModelServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, ModelServiceTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class ModelServiceAsyncClient: + """Provides methods for getting metadata information about + Generative Models. + """ + + _client: ModelServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = ModelServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = ModelServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = ModelServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = ModelServiceClient._DEFAULT_UNIVERSE + + model_path = staticmethod(ModelServiceClient.model_path) + parse_model_path = staticmethod(ModelServiceClient.parse_model_path) + tuned_model_path = staticmethod(ModelServiceClient.tuned_model_path) + parse_tuned_model_path = staticmethod(ModelServiceClient.parse_tuned_model_path) + common_billing_account_path = staticmethod( + ModelServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + ModelServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(ModelServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(ModelServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(ModelServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + ModelServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(ModelServiceClient.common_project_path) + parse_common_project_path = staticmethod( + ModelServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(ModelServiceClient.common_location_path) + parse_common_location_path = staticmethod( + ModelServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_info.__func__(ModelServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceAsyncClient: The constructed client. + """ + return ModelServiceClient.from_service_account_file.__func__(ModelServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return ModelServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> ModelServiceTransport: + """Returns the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = ModelServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, ModelServiceTransport, Callable[..., ModelServiceTransport]] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the model service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,ModelServiceTransport,Callable[..., ModelServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = ModelServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "credentialsType": None, + }, + ) + + async def get_model( + self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> model.Model: + r"""Gets information about a specific ``Model`` such as its version + number, token limits, + `parameters `__ + and other metadata. Refer to the `Gemini models + guide `__ + for detailed model information. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetModelRequest, dict]]): + The request object. Request for getting information about + a specific Model. + name (:class:`str`): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Model: + Information about a Generative + Language Model. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_model + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_models( + self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListModelsAsyncPager: + r"""Lists the + ```Model``\ s `__ + available through the Gemini API. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListModelsRequest, dict]]): + The request object. Request for listing all Models. + page_size (:class:`int`): + The maximum number of ``Models`` to return (per page). + + If unspecified, 50 models will be returned per page. + This method returns at most 1000 models per page, even + if you pass a larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (:class:`str`): + A page token, received from a previous ``ListModels`` + call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListModels`` must match the call that provided the + page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListModelsAsyncPager: + Response from ListModel containing a paginated list of + Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if page_size is not None: + request.page_size = page_size + if page_token is not None: + request.page_token = page_token + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_models + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListModelsAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_tuned_model( + self, + request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> tuned_model.TunedModel: + r"""Gets information about a specific TunedModel. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetTunedModelRequest, dict]]): + The request object. Request for getting information about + a specific Model. + name (:class:`str`): + Required. The resource name of the model. + + Format: ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetTunedModelRequest): + request = model_service.GetTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_tuned_model + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_tuned_models( + self, + request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListTunedModelsAsyncPager: + r"""Lists created tuned models. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_tuned_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListTunedModelsRequest( + ) + + # Make the request + page_result = client.list_tuned_models(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListTunedModelsRequest, dict]]): + The request object. Request for listing TunedModels. + page_size (:class:`int`): + Optional. The maximum number of ``TunedModels`` to + return (per page). The service may return fewer tuned + models. + + If unspecified, at most 10 tuned models will be + returned. This method returns at most 1000 models per + page, even if you pass a larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (:class:`str`): + Optional. A page token, received from a previous + ``ListTunedModels`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListTunedModels`` must match the call that provided + the page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListTunedModelsAsyncPager: + Response from ListTunedModels containing a paginated + list of Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListTunedModelsRequest): + request = model_service.ListTunedModelsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if page_size is not None: + request.page_size = page_size + if page_token is not None: + request.page_token = page_token + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_tuned_models + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListTunedModelsAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_tuned_model( + self, + request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + tuned_model_id: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operation_async.AsyncOperation: + r"""Creates a tuned model. Check intermediate tuning progress (if + any) through the [google.longrunning.Operations] service. + + Access status and results through the Operations service. + Example: GET /v1/tunedModels/az2mb0bpw6i/operations/000-111-222 + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateTunedModelRequest( + ) + + # Make the request + operation = client.create_tuned_model(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreateTunedModelRequest, dict]]): + The request object. Request to create a TunedModel. + tuned_model (:class:`google.ai.generativelanguage_v1alpha.types.TunedModel`): + Required. The tuned model to create. + This corresponds to the ``tuned_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tuned_model_id (:class:`str`): + Optional. The unique id for the tuned model if + specified. This value should be up to 40 characters, the + first character must be a letter, the last could be a + letter or a number. The id must match the regular + expression: ``[a-z]([a-z0-9-]{0,38}[a-z0-9])?``. + + This corresponds to the ``tuned_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation_async.AsyncOperation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.ai.generativelanguage_v1alpha.types.TunedModel` + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, tuned_model_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.CreateTunedModelRequest): + request = model_service.CreateTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if tuned_model_id is not None: + request.tuned_model_id = tuned_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_tuned_model + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation_async.from_gapic( + response, + self._client._transport.operations_client, + gag_tuned_model.TunedModel, + metadata_type=model_service.CreateTunedModelMetadata, + ) + + # Done; return the response. + return response + + async def update_tuned_model( + self, + request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_tuned_model.TunedModel: + r"""Updates a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_update_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateTunedModelRequest( + ) + + # Make the request + response = await client.update_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.UpdateTunedModelRequest, dict]]): + The request object. Request to update a TunedModel. + tuned_model (:class:`google.ai.generativelanguage_v1alpha.types.TunedModel`): + Required. The tuned model to update. + This corresponds to the ``tuned_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Optional. The list of fields to + update. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UpdateTunedModelRequest): + request = model_service.UpdateTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_tuned_model + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tuned_model.name", request.tuned_model.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_tuned_model( + self, + request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + await client.delete_tuned_model(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeleteTunedModelRequest, dict]]): + The request object. Request to delete a TunedModel. + name (:class:`str`): + Required. The resource name of the model. Format: + ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.DeleteTunedModelRequest): + request = model_service.DeleteTunedModelRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_tuned_model + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "ModelServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("ModelServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/client.py new file mode 100644 index 000000000000..41e770b7d4ba --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/client.py @@ -0,0 +1,1646 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.api_core import operation # type: ignore +from google.api_core import operation_async # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.model_service import pagers +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +from .transports.base import DEFAULT_CLIENT_INFO, ModelServiceTransport +from .transports.grpc import ModelServiceGrpcTransport +from .transports.grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .transports.rest import ModelServiceRestTransport + + +class ModelServiceClientMeta(type): + """Metaclass for the ModelService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] + _transport_registry["grpc"] = ModelServiceGrpcTransport + _transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport + _transport_registry["rest"] = ModelServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[ModelServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class ModelServiceClient(metaclass=ModelServiceClientMeta): + """Provides methods for getting metadata information about + Generative Models. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + ModelServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> ModelServiceTransport: + """Returns the transport used by the client instance. + + Returns: + ModelServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def model_path( + model: str, + ) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format( + model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def tuned_model_path( + tuned_model: str, + ) -> str: + """Returns a fully-qualified tuned_model string.""" + return "tunedModels/{tuned_model}".format( + tuned_model=tuned_model, + ) + + @staticmethod + def parse_tuned_model_path(path: str) -> Dict[str, str]: + """Parses a tuned_model path into its component segments.""" + m = re.match(r"^tunedModels/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = ModelServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = ModelServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = ModelServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = ModelServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, ModelServiceTransport, Callable[..., ModelServiceTransport]] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the model service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,ModelServiceTransport,Callable[..., ModelServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the ModelServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = ModelServiceClient._read_environment_variables() + self._client_cert_source = ModelServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = ModelServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, ModelServiceTransport) + if transport_provided: + # transport is a ModelServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(ModelServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = self._api_endpoint or ModelServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[ModelServiceTransport], Callable[..., ModelServiceTransport] + ] = ( + ModelServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., ModelServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.ModelServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "credentialsType": None, + }, + ) + + def get_model( + self, + request: Optional[Union[model_service.GetModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> model.Model: + r"""Gets information about a specific ``Model`` such as its version + number, token limits, + `parameters `__ + and other metadata. Refer to the `Gemini models + guide `__ + for detailed model information. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetModelRequest, dict]): + The request object. Request for getting information about + a specific Model. + name (str): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Model: + Information about a Generative + Language Model. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetModelRequest): + request = model_service.GetModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_models( + self, + request: Optional[Union[model_service.ListModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListModelsPager: + r"""Lists the + ```Model``\ s `__ + available through the Gemini API. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListModelsRequest, dict]): + The request object. Request for listing all Models. + page_size (int): + The maximum number of ``Models`` to return (per page). + + If unspecified, 50 models will be returned per page. + This method returns at most 1000 models per page, even + if you pass a larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (str): + A page token, received from a previous ``ListModels`` + call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListModels`` must match the call that provided the + page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListModelsPager: + Response from ListModel containing a paginated list of + Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListModelsRequest): + request = model_service.ListModelsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if page_size is not None: + request.page_size = page_size + if page_token is not None: + request.page_token = page_token + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_models] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListModelsPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_tuned_model( + self, + request: Optional[Union[model_service.GetTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> tuned_model.TunedModel: + r"""Gets information about a specific TunedModel. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetTunedModelRequest, dict]): + The request object. Request for getting information about + a specific Model. + name (str): + Required. The resource name of the model. + + Format: ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.GetTunedModelRequest): + request = model_service.GetTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_tuned_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_tuned_models( + self, + request: Optional[Union[model_service.ListTunedModelsRequest, dict]] = None, + *, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListTunedModelsPager: + r"""Lists created tuned models. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_tuned_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListTunedModelsRequest( + ) + + # Make the request + page_result = client.list_tuned_models(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListTunedModelsRequest, dict]): + The request object. Request for listing TunedModels. + page_size (int): + Optional. The maximum number of ``TunedModels`` to + return (per page). The service may return fewer tuned + models. + + If unspecified, at most 10 tuned models will be + returned. This method returns at most 1000 models per + page, even if you pass a larger page_size. + + This corresponds to the ``page_size`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + page_token (str): + Optional. A page token, received from a previous + ``ListTunedModels`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListTunedModels`` must match the call that provided + the page token. + + This corresponds to the ``page_token`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListTunedModelsPager: + Response from ListTunedModels containing a paginated + list of Models. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([page_size, page_token]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.ListTunedModelsRequest): + request = model_service.ListTunedModelsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if page_size is not None: + request.page_size = page_size + if page_token is not None: + request.page_token = page_token + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_tuned_models] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListTunedModelsPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_tuned_model( + self, + request: Optional[Union[model_service.CreateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + tuned_model_id: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operation.Operation: + r"""Creates a tuned model. Check intermediate tuning progress (if + any) through the [google.longrunning.Operations] service. + + Access status and results through the Operations service. + Example: GET /v1/tunedModels/az2mb0bpw6i/operations/000-111-222 + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateTunedModelRequest( + ) + + # Make the request + operation = client.create_tuned_model(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreateTunedModelRequest, dict]): + The request object. Request to create a TunedModel. + tuned_model (google.ai.generativelanguage_v1alpha.types.TunedModel): + Required. The tuned model to create. + This corresponds to the ``tuned_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + tuned_model_id (str): + Optional. The unique id for the tuned model if + specified. This value should be up to 40 characters, the + first character must be a letter, the last could be a + letter or a number. The id must match the regular + expression: ``[a-z]([a-z0-9-]{0,38}[a-z0-9])?``. + + This corresponds to the ``tuned_model_id`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.api_core.operation.Operation: + An object representing a long-running operation. + + The result type for the operation will be + :class:`google.ai.generativelanguage_v1alpha.types.TunedModel` + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, tuned_model_id]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.CreateTunedModelRequest): + request = model_service.CreateTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if tuned_model_id is not None: + request.tuned_model_id = tuned_model_id + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_tuned_model] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Wrap the response in an operation future. + response = operation.from_gapic( + response, + self._transport.operations_client, + gag_tuned_model.TunedModel, + metadata_type=model_service.CreateTunedModelMetadata, + ) + + # Done; return the response. + return response + + def update_tuned_model( + self, + request: Optional[Union[model_service.UpdateTunedModelRequest, dict]] = None, + *, + tuned_model: Optional[gag_tuned_model.TunedModel] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_tuned_model.TunedModel: + r"""Updates a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_update_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateTunedModelRequest( + ) + + # Make the request + response = client.update_tuned_model(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.UpdateTunedModelRequest, dict]): + The request object. Request to update a TunedModel. + tuned_model (google.ai.generativelanguage_v1alpha.types.TunedModel): + Required. The tuned model to update. + This corresponds to the ``tuned_model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Optional. The list of fields to + update. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([tuned_model, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.UpdateTunedModelRequest): + request = model_service.UpdateTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if tuned_model is not None: + request.tuned_model = tuned_model + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_tuned_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("tuned_model.name", request.tuned_model.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_tuned_model( + self, + request: Optional[Union[model_service.DeleteTunedModelRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a tuned model. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + client.delete_tuned_model(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeleteTunedModelRequest, dict]): + The request object. Request to delete a TunedModel. + name (str): + Required. The resource name of the model. Format: + ``tunedModels/my-model-id`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, model_service.DeleteTunedModelRequest): + request = model_service.DeleteTunedModelRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_tuned_model] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def __enter__(self) -> "ModelServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("ModelServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/pagers.py new file mode 100644 index 000000000000..c2266d4db067 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/pagers.py @@ -0,0 +1,353 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) + +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + +from google.ai.generativelanguage_v1alpha.types import model, model_service, tuned_model + + +class ListModelsPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListModelsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., model_service.ListModelsResponse], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListModelsResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[model.Model]: + for page in self.pages: + yield from page.models + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListModelsAsyncPager: + """A pager for iterating through ``list_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListModels`` requests and continue to iterate + through the ``models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListModelsResponse]], + request: model_service.ListModelsRequest, + response: model_service.ListModelsResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListModelsResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = model_service.ListModelsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[model_service.ListModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[model.Model]: + async def async_generator(): + async for page in self.pages: + for response in page.models: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTunedModelsPager: + """A pager for iterating through ``list_tuned_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListTunedModelsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``tuned_models`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListTunedModels`` requests and continue to iterate + through the ``tuned_models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListTunedModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., model_service.ListTunedModelsResponse], + request: model_service.ListTunedModelsRequest, + response: model_service.ListTunedModelsResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListTunedModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListTunedModelsResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = model_service.ListTunedModelsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[model_service.ListTunedModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[tuned_model.TunedModel]: + for page in self.pages: + yield from page.tuned_models + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListTunedModelsAsyncPager: + """A pager for iterating through ``list_tuned_models`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListTunedModelsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``tuned_models`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListTunedModels`` requests and continue to iterate + through the ``tuned_models`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListTunedModelsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[model_service.ListTunedModelsResponse]], + request: model_service.ListTunedModelsRequest, + response: model_service.ListTunedModelsResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListTunedModelsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListTunedModelsResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = model_service.ListTunedModelsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[model_service.ListTunedModelsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[tuned_model.TunedModel]: + async def async_generator(): + async for page in self.pages: + for response in page.tuned_models: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/README.rst new file mode 100644 index 000000000000..05dddc4c34ad --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`ModelServiceTransport` is the ABC for all transports. +- public child `ModelServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `ModelServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseModelServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `ModelServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/__init__.py new file mode 100644 index 000000000000..64fd23752c6d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import ModelServiceTransport +from .grpc import ModelServiceGrpcTransport +from .grpc_asyncio import ModelServiceGrpcAsyncIOTransport +from .rest import ModelServiceRestInterceptor, ModelServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[ModelServiceTransport]] +_transport_registry["grpc"] = ModelServiceGrpcTransport +_transport_registry["grpc_asyncio"] = ModelServiceGrpcAsyncIOTransport +_transport_registry["rest"] = ModelServiceRestTransport + +__all__ = ( + "ModelServiceTransport", + "ModelServiceGrpcTransport", + "ModelServiceGrpcAsyncIOTransport", + "ModelServiceRestTransport", + "ModelServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/base.py new file mode 100644 index 000000000000..234dba88faef --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/base.py @@ -0,0 +1,290 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, operations_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class ModelServiceTransport(abc.ABC): + """Abstract transport class for ModelService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.get_model: gapic_v1.method.wrap_method( + self.get_model, + default_timeout=None, + client_info=client_info, + ), + self.list_models: gapic_v1.method.wrap_method( + self.list_models, + default_timeout=None, + client_info=client_info, + ), + self.get_tuned_model: gapic_v1.method.wrap_method( + self.get_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.list_tuned_models: gapic_v1.method.wrap_method( + self.list_tuned_models, + default_timeout=None, + client_info=client_info, + ), + self.create_tuned_model: gapic_v1.method.wrap_method( + self.create_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.update_tuned_model: gapic_v1.method.wrap_method( + self.update_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.delete_tuned_model: gapic_v1.method.wrap_method( + self.delete_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def operations_client(self): + """Return the client designed to process long-running operations.""" + raise NotImplementedError() + + @property + def get_model( + self, + ) -> Callable[ + [model_service.GetModelRequest], Union[model.Model, Awaitable[model.Model]] + ]: + raise NotImplementedError() + + @property + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], + Union[ + model_service.ListModelsResponse, + Awaitable[model_service.ListModelsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_tuned_model( + self, + ) -> Callable[ + [model_service.GetTunedModelRequest], + Union[tuned_model.TunedModel, Awaitable[tuned_model.TunedModel]], + ]: + raise NotImplementedError() + + @property + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], + Union[ + model_service.ListTunedModelsResponse, + Awaitable[model_service.ListTunedModelsResponse], + ], + ]: + raise NotImplementedError() + + @property + def create_tuned_model( + self, + ) -> Callable[ + [model_service.CreateTunedModelRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def update_tuned_model( + self, + ) -> Callable[ + [model_service.UpdateTunedModelRequest], + Union[gag_tuned_model.TunedModel, Awaitable[gag_tuned_model.TunedModel]], + ]: + raise NotImplementedError() + + @property + def delete_tuned_model( + self, + ) -> Callable[ + [model_service.DeleteTunedModelRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("ModelServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc.py new file mode 100644 index 000000000000..501f86f0dde2 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc.py @@ -0,0 +1,582 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers, operations_v1 +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +from .base import DEFAULT_CLIENT_INFO, ModelServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class ModelServiceGrpcTransport(ModelServiceTransport): + """gRPC backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Quick check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsClient( + self._logged_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: + r"""Return a callable for the get model method over gRPC. + + Gets information about a specific ``Model`` such as its version + number, token limits, + `parameters `__ + and other metadata. Refer to the `Gemini models + guide `__ + for detailed model information. + + Returns: + Callable[[~.GetModelRequest], + ~.Model]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model" not in self._stubs: + self._stubs["get_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/GetModel", + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs["get_model"] + + @property + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: + r"""Return a callable for the list models method over gRPC. + + Lists the + ```Model``\ s `__ + available through the Gemini API. + + Returns: + Callable[[~.ListModelsRequest], + ~.ListModelsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_models" not in self._stubs: + self._stubs["list_models"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/ListModels", + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs["list_models"] + + @property + def get_tuned_model( + self, + ) -> Callable[[model_service.GetTunedModelRequest], tuned_model.TunedModel]: + r"""Return a callable for the get tuned model method over gRPC. + + Gets information about a specific TunedModel. + + Returns: + Callable[[~.GetTunedModelRequest], + ~.TunedModel]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tuned_model" not in self._stubs: + self._stubs["get_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/GetTunedModel", + request_serializer=model_service.GetTunedModelRequest.serialize, + response_deserializer=tuned_model.TunedModel.deserialize, + ) + return self._stubs["get_tuned_model"] + + @property + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], model_service.ListTunedModelsResponse + ]: + r"""Return a callable for the list tuned models method over gRPC. + + Lists created tuned models. + + Returns: + Callable[[~.ListTunedModelsRequest], + ~.ListTunedModelsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tuned_models" not in self._stubs: + self._stubs["list_tuned_models"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/ListTunedModels", + request_serializer=model_service.ListTunedModelsRequest.serialize, + response_deserializer=model_service.ListTunedModelsResponse.deserialize, + ) + return self._stubs["list_tuned_models"] + + @property + def create_tuned_model( + self, + ) -> Callable[[model_service.CreateTunedModelRequest], operations_pb2.Operation]: + r"""Return a callable for the create tuned model method over gRPC. + + Creates a tuned model. Check intermediate tuning progress (if + any) through the [google.longrunning.Operations] service. + + Access status and results through the Operations service. + Example: GET /v1/tunedModels/az2mb0bpw6i/operations/000-111-222 + + Returns: + Callable[[~.CreateTunedModelRequest], + ~.Operation]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tuned_model" not in self._stubs: + self._stubs["create_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/CreateTunedModel", + request_serializer=model_service.CreateTunedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["create_tuned_model"] + + @property + def update_tuned_model( + self, + ) -> Callable[[model_service.UpdateTunedModelRequest], gag_tuned_model.TunedModel]: + r"""Return a callable for the update tuned model method over gRPC. + + Updates a tuned model. + + Returns: + Callable[[~.UpdateTunedModelRequest], + ~.TunedModel]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tuned_model" not in self._stubs: + self._stubs["update_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/UpdateTunedModel", + request_serializer=model_service.UpdateTunedModelRequest.serialize, + response_deserializer=gag_tuned_model.TunedModel.deserialize, + ) + return self._stubs["update_tuned_model"] + + @property + def delete_tuned_model( + self, + ) -> Callable[[model_service.DeleteTunedModelRequest], empty_pb2.Empty]: + r"""Return a callable for the delete tuned model method over gRPC. + + Deletes a tuned model. + + Returns: + Callable[[~.DeleteTunedModelRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tuned_model" not in self._stubs: + self._stubs["delete_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/DeleteTunedModel", + request_serializer=model_service.DeleteTunedModelRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_tuned_model"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("ModelServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..a97df75a5c45 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/grpc_asyncio.py @@ -0,0 +1,655 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async, operations_v1 +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +from .base import DEFAULT_CLIENT_INFO, ModelServiceTransport +from .grpc import ModelServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class ModelServiceGrpcAsyncIOTransport(ModelServiceTransport): + """gRPC AsyncIO backend transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + self._operations_client: Optional[operations_v1.OperationsAsyncClient] = None + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def operations_client(self) -> operations_v1.OperationsAsyncClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Quick check: Only create a new client if we do not already have one. + if self._operations_client is None: + self._operations_client = operations_v1.OperationsAsyncClient( + self._logged_channel + ) + + # Return the client from cache. + return self._operations_client + + @property + def get_model( + self, + ) -> Callable[[model_service.GetModelRequest], Awaitable[model.Model]]: + r"""Return a callable for the get model method over gRPC. + + Gets information about a specific ``Model`` such as its version + number, token limits, + `parameters `__ + and other metadata. Refer to the `Gemini models + guide `__ + for detailed model information. + + Returns: + Callable[[~.GetModelRequest], + Awaitable[~.Model]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_model" not in self._stubs: + self._stubs["get_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/GetModel", + request_serializer=model_service.GetModelRequest.serialize, + response_deserializer=model.Model.deserialize, + ) + return self._stubs["get_model"] + + @property + def list_models( + self, + ) -> Callable[ + [model_service.ListModelsRequest], Awaitable[model_service.ListModelsResponse] + ]: + r"""Return a callable for the list models method over gRPC. + + Lists the + ```Model``\ s `__ + available through the Gemini API. + + Returns: + Callable[[~.ListModelsRequest], + Awaitable[~.ListModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_models" not in self._stubs: + self._stubs["list_models"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/ListModels", + request_serializer=model_service.ListModelsRequest.serialize, + response_deserializer=model_service.ListModelsResponse.deserialize, + ) + return self._stubs["list_models"] + + @property + def get_tuned_model( + self, + ) -> Callable[ + [model_service.GetTunedModelRequest], Awaitable[tuned_model.TunedModel] + ]: + r"""Return a callable for the get tuned model method over gRPC. + + Gets information about a specific TunedModel. + + Returns: + Callable[[~.GetTunedModelRequest], + Awaitable[~.TunedModel]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_tuned_model" not in self._stubs: + self._stubs["get_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/GetTunedModel", + request_serializer=model_service.GetTunedModelRequest.serialize, + response_deserializer=tuned_model.TunedModel.deserialize, + ) + return self._stubs["get_tuned_model"] + + @property + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], + Awaitable[model_service.ListTunedModelsResponse], + ]: + r"""Return a callable for the list tuned models method over gRPC. + + Lists created tuned models. + + Returns: + Callable[[~.ListTunedModelsRequest], + Awaitable[~.ListTunedModelsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_tuned_models" not in self._stubs: + self._stubs["list_tuned_models"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/ListTunedModels", + request_serializer=model_service.ListTunedModelsRequest.serialize, + response_deserializer=model_service.ListTunedModelsResponse.deserialize, + ) + return self._stubs["list_tuned_models"] + + @property + def create_tuned_model( + self, + ) -> Callable[ + [model_service.CreateTunedModelRequest], Awaitable[operations_pb2.Operation] + ]: + r"""Return a callable for the create tuned model method over gRPC. + + Creates a tuned model. Check intermediate tuning progress (if + any) through the [google.longrunning.Operations] service. + + Access status and results through the Operations service. + Example: GET /v1/tunedModels/az2mb0bpw6i/operations/000-111-222 + + Returns: + Callable[[~.CreateTunedModelRequest], + Awaitable[~.Operation]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_tuned_model" not in self._stubs: + self._stubs["create_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/CreateTunedModel", + request_serializer=model_service.CreateTunedModelRequest.serialize, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["create_tuned_model"] + + @property + def update_tuned_model( + self, + ) -> Callable[ + [model_service.UpdateTunedModelRequest], Awaitable[gag_tuned_model.TunedModel] + ]: + r"""Return a callable for the update tuned model method over gRPC. + + Updates a tuned model. + + Returns: + Callable[[~.UpdateTunedModelRequest], + Awaitable[~.TunedModel]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_tuned_model" not in self._stubs: + self._stubs["update_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/UpdateTunedModel", + request_serializer=model_service.UpdateTunedModelRequest.serialize, + response_deserializer=gag_tuned_model.TunedModel.deserialize, + ) + return self._stubs["update_tuned_model"] + + @property + def delete_tuned_model( + self, + ) -> Callable[[model_service.DeleteTunedModelRequest], Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the delete tuned model method over gRPC. + + Deletes a tuned model. + + Returns: + Callable[[~.DeleteTunedModelRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_tuned_model" not in self._stubs: + self._stubs["delete_tuned_model"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.ModelService/DeleteTunedModel", + request_serializer=model_service.DeleteTunedModelRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_tuned_model"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.get_model: self._wrap_method( + self.get_model, + default_timeout=None, + client_info=client_info, + ), + self.list_models: self._wrap_method( + self.list_models, + default_timeout=None, + client_info=client_info, + ), + self.get_tuned_model: self._wrap_method( + self.get_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.list_tuned_models: self._wrap_method( + self.list_tuned_models, + default_timeout=None, + client_info=client_info, + ), + self.create_tuned_model: self._wrap_method( + self.create_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.update_tuned_model: self._wrap_method( + self.update_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.delete_tuned_model: self._wrap_method( + self.delete_tuned_model, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("ModelServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest.py new file mode 100644 index 000000000000..8f3079a1e97d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest.py @@ -0,0 +1,1819 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, operations_v1, rest_helpers, rest_streaming +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseModelServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class ModelServiceRestInterceptor: + """Interceptor for ModelService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the ModelServiceRestTransport. + + .. code-block:: python + class MyCustomModelServiceInterceptor(ModelServiceRestInterceptor): + def pre_create_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_tuned_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_get_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_tuned_model(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_models(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_models(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_tuned_models(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_tuned_models(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_tuned_model(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_tuned_model(self, response): + logging.log(f"Received response: {response}") + return response + + transport = ModelServiceRestTransport(interceptor=MyCustomModelServiceInterceptor()) + client = ModelServiceClient(transport=transport) + + + """ + + def pre_create_tuned_model( + self, + request: model_service.CreateTunedModelRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + model_service.CreateTunedModelRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for create_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_create_tuned_model( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for create_tuned_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_delete_tuned_model( + self, + request: model_service.DeleteTunedModelRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + model_service.DeleteTunedModelRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for delete_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def pre_get_model( + self, + request: model_service.GetModelRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[model_service.GetModelRequest, Sequence[Tuple[str, Union[str, bytes]]]]: + """Pre-rpc interceptor for get_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_get_model(self, response: model.Model) -> model.Model: + """Post-rpc interceptor for get_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_get_tuned_model( + self, + request: model_service.GetTunedModelRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + model_service.GetTunedModelRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_get_tuned_model( + self, response: tuned_model.TunedModel + ) -> tuned_model.TunedModel: + """Post-rpc interceptor for get_tuned_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_list_models( + self, + request: model_service.ListModelsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + model_service.ListModelsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_models + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_list_models( + self, response: model_service.ListModelsResponse + ) -> model_service.ListModelsResponse: + """Post-rpc interceptor for list_models + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_list_tuned_models( + self, + request: model_service.ListTunedModelsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + model_service.ListTunedModelsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_tuned_models + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_list_tuned_models( + self, response: model_service.ListTunedModelsResponse + ) -> model_service.ListTunedModelsResponse: + """Post-rpc interceptor for list_tuned_models + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_update_tuned_model( + self, + request: model_service.UpdateTunedModelRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + model_service.UpdateTunedModelRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for update_tuned_model + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_update_tuned_model( + self, response: gag_tuned_model.TunedModel + ) -> gag_tuned_model.TunedModel: + """Post-rpc interceptor for update_tuned_model + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the ModelService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the ModelService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class ModelServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: ModelServiceRestInterceptor + + +class ModelServiceRestTransport(_BaseModelServiceRestTransport): + """REST backend synchronous transport for ModelService. + + Provides methods for getting metadata information about + Generative Models. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[ModelServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + self._operations_client: Optional[operations_v1.AbstractOperationsClient] = None + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or ModelServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + @property + def operations_client(self) -> operations_v1.AbstractOperationsClient: + """Create the client designed to process long-running operations. + + This property caches on the instance; repeated calls return the same + client. + """ + # Only create a new client if we do not already have one. + if self._operations_client is None: + http_options: Dict[str, List[Dict[str, str]]] = { + "google.longrunning.Operations.GetOperation": [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ], + "google.longrunning.Operations.ListOperations": [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ], + } + + rest_transport = operations_v1.OperationsRestTransport( + host=self._host, + # use the credentials which are saved + credentials=self._credentials, + scopes=self._scopes, + http_options=http_options, + path_prefix="v1alpha", + ) + + self._operations_client = operations_v1.AbstractOperationsClient( + transport=rest_transport + ) + + # Return the client from cache. + return self._operations_client + + class _CreateTunedModel( + _BaseModelServiceRestTransport._BaseCreateTunedModel, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.CreateTunedModel") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: model_service.CreateTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the create tuned model method over HTTP. + + Args: + request (~.model_service.CreateTunedModelRequest): + The request object. Request to create a TunedModel. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.operations_pb2.Operation: + This resource represents a + long-running operation that is the + result of a network API call. + + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseCreateTunedModel._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_tuned_model( + request, metadata + ) + transcoded_request = _BaseModelServiceRestTransport._BaseCreateTunedModel._get_transcoded_request( + http_options, request + ) + + body = _BaseModelServiceRestTransport._BaseCreateTunedModel._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseModelServiceRestTransport._BaseCreateTunedModel._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.CreateTunedModel", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "CreateTunedModel", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._CreateTunedModel._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_tuned_model(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceClient.create_tuned_model", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "CreateTunedModel", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _DeleteTunedModel( + _BaseModelServiceRestTransport._BaseDeleteTunedModel, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.DeleteTunedModel") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: model_service.DeleteTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete tuned model method over HTTP. + + Args: + request (~.model_service.DeleteTunedModelRequest): + The request object. Request to delete a TunedModel. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseDeleteTunedModel._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_tuned_model( + request, metadata + ) + transcoded_request = _BaseModelServiceRestTransport._BaseDeleteTunedModel._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseModelServiceRestTransport._BaseDeleteTunedModel._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.DeleteTunedModel", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "DeleteTunedModel", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._DeleteTunedModel._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetModel(_BaseModelServiceRestTransport._BaseGetModel, ModelServiceRestStub): + def __hash__(self): + return hash("ModelServiceRestTransport.GetModel") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: model_service.GetModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> model.Model: + r"""Call the get model method over HTTP. + + Args: + request (~.model_service.GetModelRequest): + The request object. Request for getting information about + a specific Model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.model.Model: + Information about a Generative + Language Model. + + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseGetModel._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_model(request, metadata) + transcoded_request = ( + _BaseModelServiceRestTransport._BaseGetModel._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseModelServiceRestTransport._BaseGetModel._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.GetModel", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "GetModel", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._GetModel._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = model.Model() + pb_resp = model.Model.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_model(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = model.Model.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceClient.get_model", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "GetModel", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GetTunedModel( + _BaseModelServiceRestTransport._BaseGetTunedModel, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.GetTunedModel") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: model_service.GetTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> tuned_model.TunedModel: + r"""Call the get tuned model method over HTTP. + + Args: + request (~.model_service.GetTunedModelRequest): + The request object. Request for getting information about + a specific Model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.tuned_model.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseGetTunedModel._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_tuned_model(request, metadata) + transcoded_request = _BaseModelServiceRestTransport._BaseGetTunedModel._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseModelServiceRestTransport._BaseGetTunedModel._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.GetTunedModel", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "GetTunedModel", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._GetTunedModel._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = tuned_model.TunedModel() + pb_resp = tuned_model.TunedModel.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_tuned_model(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = tuned_model.TunedModel.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceClient.get_tuned_model", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "GetTunedModel", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListModels( + _BaseModelServiceRestTransport._BaseListModels, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.ListModels") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: model_service.ListModelsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> model_service.ListModelsResponse: + r"""Call the list models method over HTTP. + + Args: + request (~.model_service.ListModelsRequest): + The request object. Request for listing all Models. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.model_service.ListModelsResponse: + Response from ``ListModel`` containing a paginated list + of Models. + + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseListModels._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_models(request, metadata) + transcoded_request = ( + _BaseModelServiceRestTransport._BaseListModels._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseModelServiceRestTransport._BaseListModels._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.ListModels", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "ListModels", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._ListModels._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = model_service.ListModelsResponse() + pb_resp = model_service.ListModelsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_models(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = model_service.ListModelsResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceClient.list_models", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "ListModels", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListTunedModels( + _BaseModelServiceRestTransport._BaseListTunedModels, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.ListTunedModels") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: model_service.ListTunedModelsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> model_service.ListTunedModelsResponse: + r"""Call the list tuned models method over HTTP. + + Args: + request (~.model_service.ListTunedModelsRequest): + The request object. Request for listing TunedModels. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.model_service.ListTunedModelsResponse: + Response from ``ListTunedModels`` containing a paginated + list of Models. + + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseListTunedModels._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_tuned_models( + request, metadata + ) + transcoded_request = _BaseModelServiceRestTransport._BaseListTunedModels._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseModelServiceRestTransport._BaseListTunedModels._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.ListTunedModels", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "ListTunedModels", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._ListTunedModels._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = model_service.ListTunedModelsResponse() + pb_resp = model_service.ListTunedModelsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_tuned_models(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = model_service.ListTunedModelsResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceClient.list_tuned_models", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "ListTunedModels", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _UpdateTunedModel( + _BaseModelServiceRestTransport._BaseUpdateTunedModel, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.UpdateTunedModel") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: model_service.UpdateTunedModelRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_tuned_model.TunedModel: + r"""Call the update tuned model method over HTTP. + + Args: + request (~.model_service.UpdateTunedModelRequest): + The request object. Request to update a TunedModel. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.gag_tuned_model.TunedModel: + A fine-tuned model created using + ModelService.CreateTunedModel. + + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseUpdateTunedModel._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_tuned_model( + request, metadata + ) + transcoded_request = _BaseModelServiceRestTransport._BaseUpdateTunedModel._get_transcoded_request( + http_options, request + ) + + body = _BaseModelServiceRestTransport._BaseUpdateTunedModel._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseModelServiceRestTransport._BaseUpdateTunedModel._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.UpdateTunedModel", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "UpdateTunedModel", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._UpdateTunedModel._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_tuned_model.TunedModel() + pb_resp = gag_tuned_model.TunedModel.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_tuned_model(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = gag_tuned_model.TunedModel.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceClient.update_tuned_model", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "UpdateTunedModel", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def create_tuned_model( + self, + ) -> Callable[[model_service.CreateTunedModelRequest], operations_pb2.Operation]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_tuned_model( + self, + ) -> Callable[[model_service.DeleteTunedModelRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_model(self) -> Callable[[model_service.GetModelRequest], model.Model]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_tuned_model( + self, + ) -> Callable[[model_service.GetTunedModelRequest], tuned_model.TunedModel]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_models( + self, + ) -> Callable[[model_service.ListModelsRequest], model_service.ListModelsResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListModels(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_tuned_models( + self, + ) -> Callable[ + [model_service.ListTunedModelsRequest], model_service.ListTunedModelsResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListTunedModels(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_tuned_model( + self, + ) -> Callable[[model_service.UpdateTunedModelRequest], gag_tuned_model.TunedModel]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateTunedModel(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseModelServiceRestTransport._BaseGetOperation, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BaseModelServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = ( + _BaseModelServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseModelServiceRestTransport._BaseListOperations, ModelServiceRestStub + ): + def __hash__(self): + return hash("ModelServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseModelServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseModelServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseModelServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.ModelServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = ModelServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.ModelService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("ModelServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest_base.py new file mode 100644 index 000000000000..da731c43b6d6 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/model_service/transports/rest_base.py @@ -0,0 +1,476 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + +from .base import DEFAULT_CLIENT_INFO, ModelServiceTransport + + +class _BaseModelServiceRestTransport(ModelServiceTransport): + """Base REST backend transport for ModelService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseCreateTunedModel: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/tunedModels", + "body": "tuned_model", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.CreateTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseModelServiceRestTransport._BaseCreateTunedModel._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeleteTunedModel: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=tunedModels/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.DeleteTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseModelServiceRestTransport._BaseDeleteTunedModel._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetModel: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=models/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.GetModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseModelServiceRestTransport._BaseGetModel._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetTunedModel: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.GetTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseModelServiceRestTransport._BaseGetTunedModel._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListModels: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/models", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.ListModelsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListTunedModels: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/tunedModels", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.ListTunedModelsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseUpdateTunedModel: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1alpha/{tuned_model.name=tunedModels/*}", + "body": "tuned_model", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = model_service.UpdateTunedModelRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseModelServiceRestTransport._BaseUpdateTunedModel._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseModelServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/__init__.py new file mode 100644 index 000000000000..ff0f9217fa0e --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import PermissionServiceAsyncClient +from .client import PermissionServiceClient + +__all__ = ( + "PermissionServiceClient", + "PermissionServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/async_client.py new file mode 100644 index 000000000000..7be8e8f26253 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/async_client.py @@ -0,0 +1,1150 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.permission_service import pagers +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +from .client import PermissionServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, PermissionServiceTransport +from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class PermissionServiceAsyncClient: + """Provides methods for managing permissions to PaLM API + resources. + """ + + _client: PermissionServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = PermissionServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PermissionServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = PermissionServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = PermissionServiceClient._DEFAULT_UNIVERSE + + permission_path = staticmethod(PermissionServiceClient.permission_path) + parse_permission_path = staticmethod(PermissionServiceClient.parse_permission_path) + common_billing_account_path = staticmethod( + PermissionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PermissionServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(PermissionServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + PermissionServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + PermissionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PermissionServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(PermissionServiceClient.common_project_path) + parse_common_project_path = staticmethod( + PermissionServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(PermissionServiceClient.common_location_path) + parse_common_location_path = staticmethod( + PermissionServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PermissionServiceAsyncClient: The constructed client. + """ + return PermissionServiceClient.from_service_account_info.__func__(PermissionServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PermissionServiceAsyncClient: The constructed client. + """ + return PermissionServiceClient.from_service_account_file.__func__(PermissionServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return PermissionServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> PermissionServiceTransport: + """Returns the transport used by the client instance. + + Returns: + PermissionServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = PermissionServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + PermissionServiceTransport, + Callable[..., PermissionServiceTransport], + ] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the permission service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,PermissionServiceTransport,Callable[..., PermissionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PermissionServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = PermissionServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "credentialsType": None, + }, + ) + + async def create_permission( + self, + request: Optional[ + Union[permission_service.CreatePermissionRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + permission: Optional[gag_permission.Permission] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_permission.Permission: + r"""Create a permission to a specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = await client.create_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreatePermissionRequest, dict]]): + The request object. Request to create a ``Permission``. + parent (:class:`str`): + Required. The parent resource of the ``Permission``. + Formats: ``tunedModels/{tuned_model}`` + ``corpora/{corpus}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + permission (:class:`google.ai.generativelanguage_v1alpha.types.Permission`): + Required. The permission to create. + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, permission]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.CreatePermissionRequest): + request = permission_service.CreatePermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if permission is not None: + request.permission = permission + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_permission + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_permission( + self, + request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission.Permission: + r"""Gets information about a specific Permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetPermissionRequest( + name="name_value", + ) + + # Make the request + response = await client.get_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetPermissionRequest, dict]]): + The request object. Request for getting information about a specific + ``Permission``. + name (:class:`str`): + Required. The resource name of the permission. + + Formats: + ``tunedModels/{tuned_model}/permissions/{permission}`` + ``corpora/{corpus}/permissions/{permission}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.GetPermissionRequest): + request = permission_service.GetPermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_permission + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_permissions( + self, + request: Optional[ + Union[permission_service.ListPermissionsRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListPermissionsAsyncPager: + r"""Lists permissions for the specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_permissions(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListPermissionsRequest, dict]]): + The request object. Request for listing permissions. + parent (:class:`str`): + Required. The parent resource of the permissions. + Formats: ``tunedModels/{tuned_model}`` + ``corpora/{corpus}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.permission_service.pagers.ListPermissionsAsyncPager: + Response from ListPermissions containing a paginated list of + permissions. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.ListPermissionsRequest): + request = permission_service.ListPermissionsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_permissions + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListPermissionsAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_permission( + self, + request: Optional[ + Union[permission_service.UpdatePermissionRequest, dict] + ] = None, + *, + permission: Optional[gag_permission.Permission] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_permission.Permission: + r"""Updates the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_update_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdatePermissionRequest( + ) + + # Make the request + response = await client.update_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.UpdatePermissionRequest, dict]]): + The request object. Request to update the ``Permission``. + permission (:class:`google.ai.generativelanguage_v1alpha.types.Permission`): + Required. The permission to update. + + The permission's ``name`` field is used to identify the + permission to update. + + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The list of fields to update. Accepted ones: + + - role (``Permission.role`` field) + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([permission, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.UpdatePermissionRequest): + request = permission_service.UpdatePermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if permission is not None: + request.permission = permission + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_permission + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("permission.name", request.permission.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_permission( + self, + request: Optional[ + Union[permission_service.DeletePermissionRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + await client.delete_permission(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeletePermissionRequest, dict]]): + The request object. Request to delete the ``Permission``. + name (:class:`str`): + Required. The resource name of the permission. Formats: + ``tunedModels/{tuned_model}/permissions/{permission}`` + ``corpora/{corpus}/permissions/{permission}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.DeletePermissionRequest): + request = permission_service.DeletePermissionRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_permission + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def transfer_ownership( + self, + request: Optional[ + Union[permission_service.TransferOwnershipRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission_service.TransferOwnershipResponse: + r"""Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = await client.transfer_ownership(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.TransferOwnershipRequest, dict]]): + The request object. Request to transfer the ownership of + the tuned model. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.TransferOwnershipResponse: + Response from TransferOwnership. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.TransferOwnershipRequest): + request = permission_service.TransferOwnershipRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.transfer_ownership + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "PermissionServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("PermissionServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/client.py new file mode 100644 index 000000000000..cf608eece1b1 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/client.py @@ -0,0 +1,1532 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.permission_service import pagers +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +from .transports.base import DEFAULT_CLIENT_INFO, PermissionServiceTransport +from .transports.grpc import PermissionServiceGrpcTransport +from .transports.grpc_asyncio import PermissionServiceGrpcAsyncIOTransport +from .transports.rest import PermissionServiceRestTransport + + +class PermissionServiceClientMeta(type): + """Metaclass for the PermissionService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PermissionServiceTransport]] + _transport_registry["grpc"] = PermissionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PermissionServiceGrpcAsyncIOTransport + _transport_registry["rest"] = PermissionServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[PermissionServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class PermissionServiceClient(metaclass=PermissionServiceClientMeta): + """Provides methods for managing permissions to PaLM API + resources. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PermissionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PermissionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PermissionServiceTransport: + """Returns the transport used by the client instance. + + Returns: + PermissionServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def permission_path( + tuned_model: str, + permission: str, + ) -> str: + """Returns a fully-qualified permission string.""" + return "tunedModels/{tuned_model}/permissions/{permission}".format( + tuned_model=tuned_model, + permission=permission, + ) + + @staticmethod + def parse_permission_path(path: str) -> Dict[str, str]: + """Parses a permission path into its component segments.""" + m = re.match( + r"^tunedModels/(?P.+?)/permissions/(?P.+?)$", path + ) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = PermissionServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = PermissionServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = PermissionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = PermissionServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + PermissionServiceTransport, + Callable[..., PermissionServiceTransport], + ] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the permission service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,PermissionServiceTransport,Callable[..., PermissionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PermissionServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = PermissionServiceClient._read_environment_variables() + self._client_cert_source = PermissionServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = PermissionServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, PermissionServiceTransport) + if transport_provided: + # transport is a PermissionServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(PermissionServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or PermissionServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[PermissionServiceTransport], + Callable[..., PermissionServiceTransport], + ] = ( + PermissionServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PermissionServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.PermissionServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "credentialsType": None, + }, + ) + + def create_permission( + self, + request: Optional[ + Union[permission_service.CreatePermissionRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + permission: Optional[gag_permission.Permission] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_permission.Permission: + r"""Create a permission to a specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = client.create_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreatePermissionRequest, dict]): + The request object. Request to create a ``Permission``. + parent (str): + Required. The parent resource of the ``Permission``. + Formats: ``tunedModels/{tuned_model}`` + ``corpora/{corpus}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + permission (google.ai.generativelanguage_v1alpha.types.Permission): + Required. The permission to create. + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, permission]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.CreatePermissionRequest): + request = permission_service.CreatePermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if permission is not None: + request.permission = permission + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_permission( + self, + request: Optional[Union[permission_service.GetPermissionRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission.Permission: + r"""Gets information about a specific Permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetPermissionRequest( + name="name_value", + ) + + # Make the request + response = client.get_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetPermissionRequest, dict]): + The request object. Request for getting information about a specific + ``Permission``. + name (str): + Required. The resource name of the permission. + + Formats: + ``tunedModels/{tuned_model}/permissions/{permission}`` + ``corpora/{corpus}/permissions/{permission}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.GetPermissionRequest): + request = permission_service.GetPermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def list_permissions( + self, + request: Optional[ + Union[permission_service.ListPermissionsRequest, dict] + ] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListPermissionsPager: + r"""Lists permissions for the specific resource. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_permissions(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListPermissionsRequest, dict]): + The request object. Request for listing permissions. + parent (str): + Required. The parent resource of the permissions. + Formats: ``tunedModels/{tuned_model}`` + ``corpora/{corpus}`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.permission_service.pagers.ListPermissionsPager: + Response from ListPermissions containing a paginated list of + permissions. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.ListPermissionsRequest): + request = permission_service.ListPermissionsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_permissions] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListPermissionsPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_permission( + self, + request: Optional[ + Union[permission_service.UpdatePermissionRequest, dict] + ] = None, + *, + permission: Optional[gag_permission.Permission] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_permission.Permission: + r"""Updates the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_update_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdatePermissionRequest( + ) + + # Make the request + response = client.update_permission(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.UpdatePermissionRequest, dict]): + The request object. Request to update the ``Permission``. + permission (google.ai.generativelanguage_v1alpha.types.Permission): + Required. The permission to update. + + The permission's ``name`` field is used to identify the + permission to update. + + This corresponds to the ``permission`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Accepted ones: + + - role (``Permission.role`` field) + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([permission, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.UpdatePermissionRequest): + request = permission_service.UpdatePermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if permission is not None: + request.permission = permission + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("permission.name", request.permission.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_permission( + self, + request: Optional[ + Union[permission_service.DeletePermissionRequest, dict] + ] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes the permission. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + client.delete_permission(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeletePermissionRequest, dict]): + The request object. Request to delete the ``Permission``. + name (str): + Required. The resource name of the permission. Formats: + ``tunedModels/{tuned_model}/permissions/{permission}`` + ``corpora/{corpus}/permissions/{permission}`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.DeletePermissionRequest): + request = permission_service.DeletePermissionRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_permission] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def transfer_ownership( + self, + request: Optional[ + Union[permission_service.TransferOwnershipRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission_service.TransferOwnershipResponse: + r"""Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = client.transfer_ownership(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.TransferOwnershipRequest, dict]): + The request object. Request to transfer the ownership of + the tuned model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.TransferOwnershipResponse: + Response from TransferOwnership. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, permission_service.TransferOwnershipRequest): + request = permission_service.TransferOwnershipRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.transfer_ownership] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "PermissionServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("PermissionServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/pagers.py new file mode 100644 index 000000000000..939d51890314 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/pagers.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) + +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + +from google.ai.generativelanguage_v1alpha.types import permission, permission_service + + +class ListPermissionsPager: + """A pager for iterating through ``list_permissions`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListPermissionsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``permissions`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListPermissions`` requests and continue to iterate + through the ``permissions`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListPermissionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., permission_service.ListPermissionsResponse], + request: permission_service.ListPermissionsRequest, + response: permission_service.ListPermissionsResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListPermissionsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListPermissionsResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = permission_service.ListPermissionsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[permission_service.ListPermissionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[permission.Permission]: + for page in self.pages: + yield from page.permissions + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListPermissionsAsyncPager: + """A pager for iterating through ``list_permissions`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListPermissionsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``permissions`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListPermissions`` requests and continue to iterate + through the ``permissions`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListPermissionsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[permission_service.ListPermissionsResponse]], + request: permission_service.ListPermissionsRequest, + response: permission_service.ListPermissionsResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListPermissionsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListPermissionsResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = permission_service.ListPermissionsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[permission_service.ListPermissionsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[permission.Permission]: + async def async_generator(): + async for page in self.pages: + for response in page.permissions: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/README.rst new file mode 100644 index 000000000000..240b1d751a3a --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`PermissionServiceTransport` is the ABC for all transports. +- public child `PermissionServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `PermissionServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BasePermissionServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `PermissionServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/__init__.py new file mode 100644 index 000000000000..05310b85abe5 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import PermissionServiceTransport +from .grpc import PermissionServiceGrpcTransport +from .grpc_asyncio import PermissionServiceGrpcAsyncIOTransport +from .rest import PermissionServiceRestInterceptor, PermissionServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[PermissionServiceTransport]] +_transport_registry["grpc"] = PermissionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PermissionServiceGrpcAsyncIOTransport +_transport_registry["rest"] = PermissionServiceRestTransport + +__all__ = ( + "PermissionServiceTransport", + "PermissionServiceGrpcTransport", + "PermissionServiceGrpcAsyncIOTransport", + "PermissionServiceRestTransport", + "PermissionServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/base.py new file mode 100644 index 000000000000..b7201aa373aa --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/base.py @@ -0,0 +1,272 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class PermissionServiceTransport(abc.ABC): + """Abstract transport class for PermissionService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_permission: gapic_v1.method.wrap_method( + self.create_permission, + default_timeout=None, + client_info=client_info, + ), + self.get_permission: gapic_v1.method.wrap_method( + self.get_permission, + default_timeout=None, + client_info=client_info, + ), + self.list_permissions: gapic_v1.method.wrap_method( + self.list_permissions, + default_timeout=None, + client_info=client_info, + ), + self.update_permission: gapic_v1.method.wrap_method( + self.update_permission, + default_timeout=None, + client_info=client_info, + ), + self.delete_permission: gapic_v1.method.wrap_method( + self.delete_permission, + default_timeout=None, + client_info=client_info, + ), + self.transfer_ownership: gapic_v1.method.wrap_method( + self.transfer_ownership, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], + Union[gag_permission.Permission, Awaitable[gag_permission.Permission]], + ]: + raise NotImplementedError() + + @property + def get_permission( + self, + ) -> Callable[ + [permission_service.GetPermissionRequest], + Union[permission.Permission, Awaitable[permission.Permission]], + ]: + raise NotImplementedError() + + @property + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + Union[ + permission_service.ListPermissionsResponse, + Awaitable[permission_service.ListPermissionsResponse], + ], + ]: + raise NotImplementedError() + + @property + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], + Union[gag_permission.Permission, Awaitable[gag_permission.Permission]], + ]: + raise NotImplementedError() + + @property + def delete_permission( + self, + ) -> Callable[ + [permission_service.DeletePermissionRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + Union[ + permission_service.TransferOwnershipResponse, + Awaitable[permission_service.TransferOwnershipResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("PermissionServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc.py new file mode 100644 index 000000000000..c5f22f7aa880 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc.py @@ -0,0 +1,541 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +from .base import DEFAULT_CLIENT_INFO, PermissionServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class PermissionServiceGrpcTransport(PermissionServiceTransport): + """gRPC backend transport for PermissionService. + + Provides methods for managing permissions to PaLM API + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], gag_permission.Permission + ]: + r"""Return a callable for the create permission method over gRPC. + + Create a permission to a specific resource. + + Returns: + Callable[[~.CreatePermissionRequest], + ~.Permission]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_permission" not in self._stubs: + self._stubs["create_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/CreatePermission", + request_serializer=permission_service.CreatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs["create_permission"] + + @property + def get_permission( + self, + ) -> Callable[[permission_service.GetPermissionRequest], permission.Permission]: + r"""Return a callable for the get permission method over gRPC. + + Gets information about a specific Permission. + + Returns: + Callable[[~.GetPermissionRequest], + ~.Permission]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_permission" not in self._stubs: + self._stubs["get_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/GetPermission", + request_serializer=permission_service.GetPermissionRequest.serialize, + response_deserializer=permission.Permission.deserialize, + ) + return self._stubs["get_permission"] + + @property + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + permission_service.ListPermissionsResponse, + ]: + r"""Return a callable for the list permissions method over gRPC. + + Lists permissions for the specific resource. + + Returns: + Callable[[~.ListPermissionsRequest], + ~.ListPermissionsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_permissions" not in self._stubs: + self._stubs["list_permissions"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/ListPermissions", + request_serializer=permission_service.ListPermissionsRequest.serialize, + response_deserializer=permission_service.ListPermissionsResponse.deserialize, + ) + return self._stubs["list_permissions"] + + @property + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], gag_permission.Permission + ]: + r"""Return a callable for the update permission method over gRPC. + + Updates the permission. + + Returns: + Callable[[~.UpdatePermissionRequest], + ~.Permission]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_permission" not in self._stubs: + self._stubs["update_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/UpdatePermission", + request_serializer=permission_service.UpdatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs["update_permission"] + + @property + def delete_permission( + self, + ) -> Callable[[permission_service.DeletePermissionRequest], empty_pb2.Empty]: + r"""Return a callable for the delete permission method over gRPC. + + Deletes the permission. + + Returns: + Callable[[~.DeletePermissionRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_permission" not in self._stubs: + self._stubs["delete_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/DeletePermission", + request_serializer=permission_service.DeletePermissionRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_permission"] + + @property + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + permission_service.TransferOwnershipResponse, + ]: + r"""Return a callable for the transfer ownership method over gRPC. + + Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + Returns: + Callable[[~.TransferOwnershipRequest], + ~.TransferOwnershipResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "transfer_ownership" not in self._stubs: + self._stubs["transfer_ownership"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/TransferOwnership", + request_serializer=permission_service.TransferOwnershipRequest.serialize, + response_deserializer=permission_service.TransferOwnershipResponse.deserialize, + ) + return self._stubs["transfer_ownership"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("PermissionServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..134d9d327043 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/grpc_asyncio.py @@ -0,0 +1,604 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +from .base import DEFAULT_CLIENT_INFO, PermissionServiceTransport +from .grpc import PermissionServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class PermissionServiceGrpcAsyncIOTransport(PermissionServiceTransport): + """gRPC AsyncIO backend transport for PermissionService. + + Provides methods for managing permissions to PaLM API + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], + Awaitable[gag_permission.Permission], + ]: + r"""Return a callable for the create permission method over gRPC. + + Create a permission to a specific resource. + + Returns: + Callable[[~.CreatePermissionRequest], + Awaitable[~.Permission]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_permission" not in self._stubs: + self._stubs["create_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/CreatePermission", + request_serializer=permission_service.CreatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs["create_permission"] + + @property + def get_permission( + self, + ) -> Callable[ + [permission_service.GetPermissionRequest], Awaitable[permission.Permission] + ]: + r"""Return a callable for the get permission method over gRPC. + + Gets information about a specific Permission. + + Returns: + Callable[[~.GetPermissionRequest], + Awaitable[~.Permission]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_permission" not in self._stubs: + self._stubs["get_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/GetPermission", + request_serializer=permission_service.GetPermissionRequest.serialize, + response_deserializer=permission.Permission.deserialize, + ) + return self._stubs["get_permission"] + + @property + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + Awaitable[permission_service.ListPermissionsResponse], + ]: + r"""Return a callable for the list permissions method over gRPC. + + Lists permissions for the specific resource. + + Returns: + Callable[[~.ListPermissionsRequest], + Awaitable[~.ListPermissionsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_permissions" not in self._stubs: + self._stubs["list_permissions"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/ListPermissions", + request_serializer=permission_service.ListPermissionsRequest.serialize, + response_deserializer=permission_service.ListPermissionsResponse.deserialize, + ) + return self._stubs["list_permissions"] + + @property + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], + Awaitable[gag_permission.Permission], + ]: + r"""Return a callable for the update permission method over gRPC. + + Updates the permission. + + Returns: + Callable[[~.UpdatePermissionRequest], + Awaitable[~.Permission]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_permission" not in self._stubs: + self._stubs["update_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/UpdatePermission", + request_serializer=permission_service.UpdatePermissionRequest.serialize, + response_deserializer=gag_permission.Permission.deserialize, + ) + return self._stubs["update_permission"] + + @property + def delete_permission( + self, + ) -> Callable[ + [permission_service.DeletePermissionRequest], Awaitable[empty_pb2.Empty] + ]: + r"""Return a callable for the delete permission method over gRPC. + + Deletes the permission. + + Returns: + Callable[[~.DeletePermissionRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_permission" not in self._stubs: + self._stubs["delete_permission"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/DeletePermission", + request_serializer=permission_service.DeletePermissionRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_permission"] + + @property + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + Awaitable[permission_service.TransferOwnershipResponse], + ]: + r"""Return a callable for the transfer ownership method over gRPC. + + Transfers ownership of the tuned model. + This is the only way to change ownership of the tuned + model. The current owner will be downgraded to writer + role. + + Returns: + Callable[[~.TransferOwnershipRequest], + Awaitable[~.TransferOwnershipResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "transfer_ownership" not in self._stubs: + self._stubs["transfer_ownership"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PermissionService/TransferOwnership", + request_serializer=permission_service.TransferOwnershipRequest.serialize, + response_deserializer=permission_service.TransferOwnershipResponse.deserialize, + ) + return self._stubs["transfer_ownership"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_permission: self._wrap_method( + self.create_permission, + default_timeout=None, + client_info=client_info, + ), + self.get_permission: self._wrap_method( + self.get_permission, + default_timeout=None, + client_info=client_info, + ), + self.list_permissions: self._wrap_method( + self.list_permissions, + default_timeout=None, + client_info=client_info, + ), + self.update_permission: self._wrap_method( + self.update_permission, + default_timeout=None, + client_info=client_info, + ), + self.delete_permission: self._wrap_method( + self.delete_permission, + default_timeout=None, + client_info=client_info, + ), + self.transfer_ownership: self._wrap_method( + self.transfer_ownership, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("PermissionServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest.py new file mode 100644 index 000000000000..559cf0c08038 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest.py @@ -0,0 +1,1671 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BasePermissionServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class PermissionServiceRestInterceptor: + """Interceptor for PermissionService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the PermissionServiceRestTransport. + + .. code-block:: python + class MyCustomPermissionServiceInterceptor(PermissionServiceRestInterceptor): + def pre_create_permission(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_permission(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_permission(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_permission(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_permission(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_permissions(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_permissions(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_transfer_ownership(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_transfer_ownership(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_permission(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_permission(self, response): + logging.log(f"Received response: {response}") + return response + + transport = PermissionServiceRestTransport(interceptor=MyCustomPermissionServiceInterceptor()) + client = PermissionServiceClient(transport=transport) + + + """ + + def pre_create_permission( + self, + request: permission_service.CreatePermissionRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + permission_service.CreatePermissionRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for create_permission + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_create_permission( + self, response: gag_permission.Permission + ) -> gag_permission.Permission: + """Post-rpc interceptor for create_permission + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + def pre_delete_permission( + self, + request: permission_service.DeletePermissionRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + permission_service.DeletePermissionRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for delete_permission + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def pre_get_permission( + self, + request: permission_service.GetPermissionRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + permission_service.GetPermissionRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_permission + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_get_permission( + self, response: permission.Permission + ) -> permission.Permission: + """Post-rpc interceptor for get_permission + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + def pre_list_permissions( + self, + request: permission_service.ListPermissionsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + permission_service.ListPermissionsRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for list_permissions + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_list_permissions( + self, response: permission_service.ListPermissionsResponse + ) -> permission_service.ListPermissionsResponse: + """Post-rpc interceptor for list_permissions + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + def pre_transfer_ownership( + self, + request: permission_service.TransferOwnershipRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + permission_service.TransferOwnershipRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for transfer_ownership + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_transfer_ownership( + self, response: permission_service.TransferOwnershipResponse + ) -> permission_service.TransferOwnershipResponse: + """Post-rpc interceptor for transfer_ownership + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + def pre_update_permission( + self, + request: permission_service.UpdatePermissionRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + permission_service.UpdatePermissionRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for update_permission + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_update_permission( + self, response: gag_permission.Permission + ) -> gag_permission.Permission: + """Post-rpc interceptor for update_permission + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the PermissionService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the PermissionService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class PermissionServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: PermissionServiceRestInterceptor + + +class PermissionServiceRestTransport(_BasePermissionServiceRestTransport): + """REST backend synchronous transport for PermissionService. + + Provides methods for managing permissions to PaLM API + resources. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[PermissionServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or PermissionServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _CreatePermission( + _BasePermissionServiceRestTransport._BaseCreatePermission, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.CreatePermission") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: permission_service.CreatePermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_permission.Permission: + r"""Call the create permission method over HTTP. + + Args: + request (~.permission_service.CreatePermissionRequest): + The request object. Request to create a ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.gag_permission.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseCreatePermission._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_permission( + request, metadata + ) + transcoded_request = _BasePermissionServiceRestTransport._BaseCreatePermission._get_transcoded_request( + http_options, request + ) + + body = _BasePermissionServiceRestTransport._BaseCreatePermission._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseCreatePermission._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.CreatePermission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "CreatePermission", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._CreatePermission._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_permission.Permission() + pb_resp = gag_permission.Permission.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_permission(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = gag_permission.Permission.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceClient.create_permission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "CreatePermission", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _DeletePermission( + _BasePermissionServiceRestTransport._BaseDeletePermission, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.DeletePermission") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: permission_service.DeletePermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete permission method over HTTP. + + Args: + request (~.permission_service.DeletePermissionRequest): + The request object. Request to delete the ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseDeletePermission._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_permission( + request, metadata + ) + transcoded_request = _BasePermissionServiceRestTransport._BaseDeletePermission._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseDeletePermission._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.DeletePermission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "DeletePermission", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._DeletePermission._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetPermission( + _BasePermissionServiceRestTransport._BaseGetPermission, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.GetPermission") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: permission_service.GetPermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission.Permission: + r"""Call the get permission method over HTTP. + + Args: + request (~.permission_service.GetPermissionRequest): + The request object. Request for getting information about a specific + ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.permission.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseGetPermission._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_permission(request, metadata) + transcoded_request = _BasePermissionServiceRestTransport._BaseGetPermission._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseGetPermission._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.GetPermission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "GetPermission", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._GetPermission._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = permission.Permission() + pb_resp = permission.Permission.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_permission(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = permission.Permission.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceClient.get_permission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "GetPermission", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListPermissions( + _BasePermissionServiceRestTransport._BaseListPermissions, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.ListPermissions") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: permission_service.ListPermissionsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission_service.ListPermissionsResponse: + r"""Call the list permissions method over HTTP. + + Args: + request (~.permission_service.ListPermissionsRequest): + The request object. Request for listing permissions. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.permission_service.ListPermissionsResponse: + Response from ``ListPermissions`` containing a paginated + list of permissions. + + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseListPermissions._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_permissions( + request, metadata + ) + transcoded_request = _BasePermissionServiceRestTransport._BaseListPermissions._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseListPermissions._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.ListPermissions", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "ListPermissions", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._ListPermissions._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = permission_service.ListPermissionsResponse() + pb_resp = permission_service.ListPermissionsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_permissions(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + permission_service.ListPermissionsResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceClient.list_permissions", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "ListPermissions", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _TransferOwnership( + _BasePermissionServiceRestTransport._BaseTransferOwnership, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.TransferOwnership") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: permission_service.TransferOwnershipRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> permission_service.TransferOwnershipResponse: + r"""Call the transfer ownership method over HTTP. + + Args: + request (~.permission_service.TransferOwnershipRequest): + The request object. Request to transfer the ownership of + the tuned model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.permission_service.TransferOwnershipResponse: + Response from ``TransferOwnership``. + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseTransferOwnership._get_http_options() + ) + + request, metadata = self._interceptor.pre_transfer_ownership( + request, metadata + ) + transcoded_request = _BasePermissionServiceRestTransport._BaseTransferOwnership._get_transcoded_request( + http_options, request + ) + + body = _BasePermissionServiceRestTransport._BaseTransferOwnership._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseTransferOwnership._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.TransferOwnership", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "TransferOwnership", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._TransferOwnership._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = permission_service.TransferOwnershipResponse() + pb_resp = permission_service.TransferOwnershipResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_transfer_ownership(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + permission_service.TransferOwnershipResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceClient.transfer_ownership", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "TransferOwnership", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _UpdatePermission( + _BasePermissionServiceRestTransport._BaseUpdatePermission, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.UpdatePermission") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: permission_service.UpdatePermissionRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> gag_permission.Permission: + r"""Call the update permission method over HTTP. + + Args: + request (~.permission_service.UpdatePermissionRequest): + The request object. Request to update the ``Permission``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.gag_permission.Permission: + Permission resource grants user, + group or the rest of the world access to + the PaLM API resource (e.g. a tuned + model, corpus). + + A role is a collection of permitted + operations that allows users to perform + specific actions on PaLM API resources. + To make them available to users, groups, + or service accounts, you assign roles. + When you assign a role, you grant + permissions that the role contains. + + There are three concentric roles. Each + role is a superset of the previous + role's permitted operations: + + - reader can use the resource (e.g. + tuned model, corpus) for inference + - writer has reader's permissions and + additionally can edit and share + - owner has writer's permissions and + additionally can delete + + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseUpdatePermission._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_permission( + request, metadata + ) + transcoded_request = _BasePermissionServiceRestTransport._BaseUpdatePermission._get_transcoded_request( + http_options, request + ) + + body = _BasePermissionServiceRestTransport._BaseUpdatePermission._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseUpdatePermission._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.UpdatePermission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "UpdatePermission", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._UpdatePermission._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = gag_permission.Permission() + pb_resp = gag_permission.Permission.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_permission(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = gag_permission.Permission.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceClient.update_permission", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "UpdatePermission", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def create_permission( + self, + ) -> Callable[ + [permission_service.CreatePermissionRequest], gag_permission.Permission + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreatePermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_permission( + self, + ) -> Callable[[permission_service.DeletePermissionRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeletePermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_permission( + self, + ) -> Callable[[permission_service.GetPermissionRequest], permission.Permission]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetPermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_permissions( + self, + ) -> Callable[ + [permission_service.ListPermissionsRequest], + permission_service.ListPermissionsResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListPermissions(self._session, self._host, self._interceptor) # type: ignore + + @property + def transfer_ownership( + self, + ) -> Callable[ + [permission_service.TransferOwnershipRequest], + permission_service.TransferOwnershipResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._TransferOwnership(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_permission( + self, + ) -> Callable[ + [permission_service.UpdatePermissionRequest], gag_permission.Permission + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdatePermission(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BasePermissionServiceRestTransport._BaseGetOperation, PermissionServiceRestStub + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BasePermissionServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BasePermissionServiceRestTransport._BaseListOperations, + PermissionServiceRestStub, + ): + def __hash__(self): + return hash("PermissionServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BasePermissionServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BasePermissionServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePermissionServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PermissionServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PermissionServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PermissionService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("PermissionServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest_base.py new file mode 100644 index 000000000000..56e09878b023 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/permission_service/transports/rest_base.py @@ -0,0 +1,493 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + +from .base import DEFAULT_CLIENT_INFO, PermissionServiceTransport + + +class _BasePermissionServiceRestTransport(PermissionServiceTransport): + """Base REST backend transport for PermissionService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseCreatePermission: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{parent=tunedModels/*}/permissions", + "body": "permission", + }, + { + "method": "post", + "uri": "/v1alpha/{parent=corpora/*}/permissions", + "body": "permission", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = permission_service.CreatePermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePermissionServiceRestTransport._BaseCreatePermission._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeletePermission: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=tunedModels/*/permissions/*}", + }, + { + "method": "delete", + "uri": "/v1alpha/{name=corpora/*/permissions/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = permission_service.DeletePermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePermissionServiceRestTransport._BaseDeletePermission._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetPermission: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/permissions/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=corpora/*/permissions/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = permission_service.GetPermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePermissionServiceRestTransport._BaseGetPermission._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListPermissions: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{parent=tunedModels/*}/permissions", + }, + { + "method": "get", + "uri": "/v1alpha/{parent=corpora/*}/permissions", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = permission_service.ListPermissionsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePermissionServiceRestTransport._BaseListPermissions._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseTransferOwnership: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{name=tunedModels/*}:transferOwnership", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = permission_service.TransferOwnershipRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePermissionServiceRestTransport._BaseTransferOwnership._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseUpdatePermission: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1alpha/{permission.name=tunedModels/*/permissions/*}", + "body": "permission", + }, + { + "method": "patch", + "uri": "/v1alpha/{permission.name=corpora/*/permissions/*}", + "body": "permission", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = permission_service.UpdatePermissionRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePermissionServiceRestTransport._BaseUpdatePermission._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BasePermissionServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/__init__.py new file mode 100644 index 000000000000..6c64cf5ad1c0 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import PredictionServiceAsyncClient +from .client import PredictionServiceClient + +__all__ = ( + "PredictionServiceClient", + "PredictionServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/async_client.py new file mode 100644 index 000000000000..fe4f615d5413 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/async_client.py @@ -0,0 +1,535 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import struct_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import prediction_service + +from .client import PredictionServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, PredictionServiceTransport +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class PredictionServiceAsyncClient: + """A service for online predictions and explanations.""" + + _client: PredictionServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = PredictionServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = PredictionServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = PredictionServiceClient._DEFAULT_UNIVERSE + + model_path = staticmethod(PredictionServiceClient.model_path) + parse_model_path = staticmethod(PredictionServiceClient.parse_model_path) + common_billing_account_path = staticmethod( + PredictionServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + PredictionServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(PredictionServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + PredictionServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + PredictionServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + PredictionServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(PredictionServiceClient.common_project_path) + parse_common_project_path = staticmethod( + PredictionServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(PredictionServiceClient.common_location_path) + parse_common_location_path = staticmethod( + PredictionServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceAsyncClient: The constructed client. + """ + return PredictionServiceClient.from_service_account_info.__func__(PredictionServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceAsyncClient: The constructed client. + """ + return PredictionServiceClient.from_service_account_file.__func__(PredictionServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return PredictionServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> PredictionServiceTransport: + """Returns the transport used by the client instance. + + Returns: + PredictionServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = PredictionServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + PredictionServiceTransport, + Callable[..., PredictionServiceTransport], + ] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the prediction service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,PredictionServiceTransport,Callable[..., PredictionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PredictionServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = PredictionServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.PredictionServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "credentialsType": None, + }, + ) + + async def predict( + self, + request: Optional[Union[prediction_service.PredictRequest, dict]] = None, + *, + model: Optional[str] = None, + instances: Optional[MutableSequence[struct_pb2.Value]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> prediction_service.PredictResponse: + r"""Performs a prediction request. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_predict(): + # Create a client + client = generativelanguage_v1alpha.PredictionServiceAsyncClient() + + # Initialize request argument(s) + instances = generativelanguage_v1alpha.Value() + instances.null_value = "NULL_VALUE" + + request = generativelanguage_v1alpha.PredictRequest( + model="model_value", + instances=instances, + ) + + # Make the request + response = await client.predict(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.PredictRequest, dict]]): + The request object. Request message for + [PredictionService.Predict][google.ai.generativelanguage.v1alpha.PredictionService.Predict]. + model (:class:`str`): + Required. The name of the model for prediction. Format: + ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (:class:`MutableSequence[google.protobuf.struct_pb2.Value]`): + Required. The instances that are the + input to the prediction call. + + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.PredictResponse: + Response message for [PredictionService.Predict]. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, instances]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if instances: + request.instances.extend(instances) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[self._client._transport.predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "PredictionServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("PredictionServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/client.py new file mode 100644 index 000000000000..622ce9970e2d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/client.py @@ -0,0 +1,929 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import struct_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import prediction_service + +from .transports.base import DEFAULT_CLIENT_INFO, PredictionServiceTransport +from .transports.grpc import PredictionServiceGrpcTransport +from .transports.grpc_asyncio import PredictionServiceGrpcAsyncIOTransport +from .transports.rest import PredictionServiceRestTransport + + +class PredictionServiceClientMeta(type): + """Metaclass for the PredictionService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[PredictionServiceTransport]] + _transport_registry["grpc"] = PredictionServiceGrpcTransport + _transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport + _transport_registry["rest"] = PredictionServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[PredictionServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class PredictionServiceClient(metaclass=PredictionServiceClientMeta): + """A service for online predictions and explanations.""" + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + PredictionServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> PredictionServiceTransport: + """Returns the transport used by the client instance. + + Returns: + PredictionServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def model_path( + model: str, + ) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format( + model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = PredictionServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = PredictionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = PredictionServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, + PredictionServiceTransport, + Callable[..., PredictionServiceTransport], + ] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the prediction service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,PredictionServiceTransport,Callable[..., PredictionServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the PredictionServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = PredictionServiceClient._read_environment_variables() + self._client_cert_source = PredictionServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = PredictionServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, PredictionServiceTransport) + if transport_provided: + # transport is a PredictionServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(PredictionServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or PredictionServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[PredictionServiceTransport], + Callable[..., PredictionServiceTransport], + ] = ( + PredictionServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., PredictionServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.PredictionServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "credentialsType": None, + }, + ) + + def predict( + self, + request: Optional[Union[prediction_service.PredictRequest, dict]] = None, + *, + model: Optional[str] = None, + instances: Optional[MutableSequence[struct_pb2.Value]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> prediction_service.PredictResponse: + r"""Performs a prediction request. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_predict(): + # Create a client + client = generativelanguage_v1alpha.PredictionServiceClient() + + # Initialize request argument(s) + instances = generativelanguage_v1alpha.Value() + instances.null_value = "NULL_VALUE" + + request = generativelanguage_v1alpha.PredictRequest( + model="model_value", + instances=instances, + ) + + # Make the request + response = client.predict(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.PredictRequest, dict]): + The request object. Request message for + [PredictionService.Predict][google.ai.generativelanguage.v1alpha.PredictionService.Predict]. + model (str): + Required. The name of the model for prediction. Format: + ``name=models/{model}``. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + instances (MutableSequence[google.protobuf.struct_pb2.Value]): + Required. The instances that are the + input to the prediction call. + + This corresponds to the ``instances`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.PredictResponse: + Response message for [PredictionService.Predict]. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, instances]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, prediction_service.PredictRequest): + request = prediction_service.PredictRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if instances is not None: + request.instances.extend(instances) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.predict] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "PredictionServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("PredictionServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/README.rst new file mode 100644 index 000000000000..504aaca0a144 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`PredictionServiceTransport` is the ABC for all transports. +- public child `PredictionServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `PredictionServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BasePredictionServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `PredictionServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/__init__.py new file mode 100644 index 000000000000..d6d645ba1ff1 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import PredictionServiceTransport +from .grpc import PredictionServiceGrpcTransport +from .grpc_asyncio import PredictionServiceGrpcAsyncIOTransport +from .rest import PredictionServiceRestInterceptor, PredictionServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[PredictionServiceTransport]] +_transport_registry["grpc"] = PredictionServiceGrpcTransport +_transport_registry["grpc_asyncio"] = PredictionServiceGrpcAsyncIOTransport +_transport_registry["rest"] = PredictionServiceRestTransport + +__all__ = ( + "PredictionServiceTransport", + "PredictionServiceGrpcTransport", + "PredictionServiceGrpcAsyncIOTransport", + "PredictionServiceRestTransport", + "PredictionServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/base.py new file mode 100644 index 000000000000..e3b9053178c3 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/base.py @@ -0,0 +1,196 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import prediction_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class PredictionServiceTransport(abc.ABC): + """Abstract transport class for PredictionService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.predict: gapic_v1.method.wrap_method( + self.predict, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], + Union[ + prediction_service.PredictResponse, + Awaitable[prediction_service.PredictResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("PredictionServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc.py new file mode 100644 index 000000000000..dd72b9c0b00e --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc.py @@ -0,0 +1,396 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import prediction_service + +from .base import DEFAULT_CLIENT_INFO, PredictionServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class PredictionServiceGrpcTransport(PredictionServiceTransport): + """gRPC backend transport for PredictionService. + + A service for online predictions and explanations. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], prediction_service.PredictResponse + ]: + r"""Return a callable for the predict method over gRPC. + + Performs a prediction request. + + Returns: + Callable[[~.PredictRequest], + ~.PredictResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "predict" not in self._stubs: + self._stubs["predict"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PredictionService/Predict", + request_serializer=prediction_service.PredictRequest.serialize, + response_deserializer=prediction_service.PredictResponse.deserialize, + ) + return self._stubs["predict"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("PredictionServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..cbab29697995 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/grpc_asyncio.py @@ -0,0 +1,429 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import prediction_service + +from .base import DEFAULT_CLIENT_INFO, PredictionServiceTransport +from .grpc import PredictionServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class PredictionServiceGrpcAsyncIOTransport(PredictionServiceTransport): + """gRPC AsyncIO backend transport for PredictionService. + + A service for online predictions and explanations. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], + Awaitable[prediction_service.PredictResponse], + ]: + r"""Return a callable for the predict method over gRPC. + + Performs a prediction request. + + Returns: + Callable[[~.PredictRequest], + Awaitable[~.PredictResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "predict" not in self._stubs: + self._stubs["predict"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.PredictionService/Predict", + request_serializer=prediction_service.PredictRequest.serialize, + response_deserializer=prediction_service.PredictResponse.deserialize, + ) + return self._stubs["predict"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.predict: self._wrap_method( + self.predict, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("PredictionServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest.py new file mode 100644 index 000000000000..7dbce1753a9b --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest.py @@ -0,0 +1,700 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import prediction_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BasePredictionServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class PredictionServiceRestInterceptor: + """Interceptor for PredictionService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the PredictionServiceRestTransport. + + .. code-block:: python + class MyCustomPredictionServiceInterceptor(PredictionServiceRestInterceptor): + def pre_predict(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_predict(self, response): + logging.log(f"Received response: {response}") + return response + + transport = PredictionServiceRestTransport(interceptor=MyCustomPredictionServiceInterceptor()) + client = PredictionServiceClient(transport=transport) + + + """ + + def pre_predict( + self, + request: prediction_service.PredictRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + prediction_service.PredictRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for predict + + Override in a subclass to manipulate the request or metadata + before they are sent to the PredictionService server. + """ + return request, metadata + + def post_predict( + self, response: prediction_service.PredictResponse + ) -> prediction_service.PredictResponse: + """Post-rpc interceptor for predict + + Override in a subclass to manipulate the response + after it is returned by the PredictionService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the PredictionService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the PredictionService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the PredictionService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the PredictionService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class PredictionServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: PredictionServiceRestInterceptor + + +class PredictionServiceRestTransport(_BasePredictionServiceRestTransport): + """REST backend synchronous transport for PredictionService. + + A service for online predictions and explanations. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[PredictionServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or PredictionServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _Predict( + _BasePredictionServiceRestTransport._BasePredict, PredictionServiceRestStub + ): + def __hash__(self): + return hash("PredictionServiceRestTransport.Predict") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: prediction_service.PredictRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> prediction_service.PredictResponse: + r"""Call the predict method over HTTP. + + Args: + request (~.prediction_service.PredictRequest): + The request object. Request message for + [PredictionService.Predict][google.ai.generativelanguage.v1alpha.PredictionService.Predict]. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.prediction_service.PredictResponse: + Response message for [PredictionService.Predict]. + """ + + http_options = ( + _BasePredictionServiceRestTransport._BasePredict._get_http_options() + ) + + request, metadata = self._interceptor.pre_predict(request, metadata) + transcoded_request = _BasePredictionServiceRestTransport._BasePredict._get_transcoded_request( + http_options, request + ) + + body = ( + _BasePredictionServiceRestTransport._BasePredict._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BasePredictionServiceRestTransport._BasePredict._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PredictionServiceClient.Predict", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": "Predict", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PredictionServiceRestTransport._Predict._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = prediction_service.PredictResponse() + pb_resp = prediction_service.PredictResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_predict(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = prediction_service.PredictResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PredictionServiceClient.predict", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": "Predict", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def predict( + self, + ) -> Callable[ + [prediction_service.PredictRequest], prediction_service.PredictResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._Predict(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BasePredictionServiceRestTransport._BaseGetOperation, PredictionServiceRestStub + ): + def __hash__(self): + return hash("PredictionServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BasePredictionServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BasePredictionServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePredictionServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PredictionServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PredictionServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PredictionServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BasePredictionServiceRestTransport._BaseListOperations, + PredictionServiceRestStub, + ): + def __hash__(self): + return hash("PredictionServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BasePredictionServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BasePredictionServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BasePredictionServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.PredictionServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = PredictionServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.PredictionServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.PredictionService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("PredictionServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest_base.py new file mode 100644 index 000000000000..8fb2d91436b3 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/prediction_service/transports/rest_base.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import prediction_service + +from .base import DEFAULT_CLIENT_INFO, PredictionServiceTransport + + +class _BasePredictionServiceRestTransport(PredictionServiceTransport): + """Base REST backend transport for PredictionService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BasePredict: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:predict", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = prediction_service.PredictRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BasePredictionServiceRestTransport._BasePredict._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BasePredictionServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/__init__.py new file mode 100644 index 000000000000..8c3e3ba77075 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import RetrieverServiceAsyncClient +from .client import RetrieverServiceClient + +__all__ = ( + "RetrieverServiceClient", + "RetrieverServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/async_client.py new file mode 100644 index 000000000000..e04f95549d7b --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/async_client.py @@ -0,0 +1,2499 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.retriever_service import pagers +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +from .client import RetrieverServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, RetrieverServiceTransport +from .transports.grpc_asyncio import RetrieverServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class RetrieverServiceAsyncClient: + """An API for semantic search over a corpus of user uploaded + content. + """ + + _client: RetrieverServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = RetrieverServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = RetrieverServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = RetrieverServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = RetrieverServiceClient._DEFAULT_UNIVERSE + + chunk_path = staticmethod(RetrieverServiceClient.chunk_path) + parse_chunk_path = staticmethod(RetrieverServiceClient.parse_chunk_path) + corpus_path = staticmethod(RetrieverServiceClient.corpus_path) + parse_corpus_path = staticmethod(RetrieverServiceClient.parse_corpus_path) + document_path = staticmethod(RetrieverServiceClient.document_path) + parse_document_path = staticmethod(RetrieverServiceClient.parse_document_path) + common_billing_account_path = staticmethod( + RetrieverServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + RetrieverServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(RetrieverServiceClient.common_folder_path) + parse_common_folder_path = staticmethod( + RetrieverServiceClient.parse_common_folder_path + ) + common_organization_path = staticmethod( + RetrieverServiceClient.common_organization_path + ) + parse_common_organization_path = staticmethod( + RetrieverServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(RetrieverServiceClient.common_project_path) + parse_common_project_path = staticmethod( + RetrieverServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(RetrieverServiceClient.common_location_path) + parse_common_location_path = staticmethod( + RetrieverServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + RetrieverServiceAsyncClient: The constructed client. + """ + return RetrieverServiceClient.from_service_account_info.__func__(RetrieverServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + RetrieverServiceAsyncClient: The constructed client. + """ + return RetrieverServiceClient.from_service_account_file.__func__(RetrieverServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return RetrieverServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> RetrieverServiceTransport: + """Returns the transport used by the client instance. + + Returns: + RetrieverServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = RetrieverServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, RetrieverServiceTransport, Callable[..., RetrieverServiceTransport] + ] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the retriever service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,RetrieverServiceTransport,Callable[..., RetrieverServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the RetrieverServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = RetrieverServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "credentialsType": None, + }, + ) + + async def create_corpus( + self, + request: Optional[Union[retriever_service.CreateCorpusRequest, dict]] = None, + *, + corpus: Optional[retriever.Corpus] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Creates an empty ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCorpusRequest( + ) + + # Make the request + response = await client.create_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreateCorpusRequest, dict]]): + The request object. Request to create a ``Corpus``. + corpus (:class:`google.ai.generativelanguage_v1alpha.types.Corpus`): + Required. The ``Corpus`` to create. + This corresponds to the ``corpus`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Corpus: + A Corpus is a collection of Documents. + A project can create up to 5 corpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([corpus]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.CreateCorpusRequest): + request = retriever_service.CreateCorpusRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if corpus is not None: + request.corpus = corpus + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_corpus + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_corpus( + self, + request: Optional[Union[retriever_service.GetCorpusRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Gets information about a specific ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCorpusRequest( + name="name_value", + ) + + # Make the request + response = await client.get_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetCorpusRequest, dict]]): + The request object. Request for getting information about a specific + ``Corpus``. + name (:class:`str`): + Required. The name of the ``Corpus``. Example: + ``corpora/my-corpus-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Corpus: + A Corpus is a collection of Documents. + A project can create up to 5 corpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.GetCorpusRequest): + request = retriever_service.GetCorpusRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_corpus + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_corpus( + self, + request: Optional[Union[retriever_service.UpdateCorpusRequest, dict]] = None, + *, + corpus: Optional[retriever.Corpus] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Updates a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_update_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCorpusRequest( + ) + + # Make the request + response = await client.update_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.UpdateCorpusRequest, dict]]): + The request object. Request to update a ``Corpus``. + corpus (:class:`google.ai.generativelanguage_v1alpha.types.Corpus`): + Required. The ``Corpus`` to update. + This corresponds to the ``corpus`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The list of fields to update. Currently, this + only supports updating ``display_name``. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Corpus: + A Corpus is a collection of Documents. + A project can create up to 5 corpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([corpus, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.UpdateCorpusRequest): + request = retriever_service.UpdateCorpusRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if corpus is not None: + request.corpus = corpus + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_corpus + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("corpus.name", request.corpus.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_corpus( + self, + request: Optional[Union[retriever_service.DeleteCorpusRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCorpusRequest( + name="name_value", + ) + + # Make the request + await client.delete_corpus(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeleteCorpusRequest, dict]]): + The request object. Request to delete a ``Corpus``. + name (:class:`str`): + Required. The resource name of the ``Corpus``. Example: + ``corpora/my-corpus-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.DeleteCorpusRequest): + request = retriever_service.DeleteCorpusRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_corpus + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def list_corpora( + self, + request: Optional[Union[retriever_service.ListCorporaRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListCorporaAsyncPager: + r"""Lists all ``Corpora`` owned by the user. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_corpora(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCorporaRequest( + ) + + # Make the request + page_result = client.list_corpora(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListCorporaRequest, dict]]): + The request object. Request for listing ``Corpora``. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListCorporaAsyncPager: + Response from ListCorpora containing a paginated list of Corpora. + The results are sorted by ascending + corpus.create_time. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.ListCorporaRequest): + request = retriever_service.ListCorporaRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_corpora + ] + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListCorporaAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def query_corpus( + self, + request: Optional[Union[retriever_service.QueryCorpusRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.QueryCorpusResponse: + r"""Performs semantic search over a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_query_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryCorpusRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = await client.query_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.QueryCorpusRequest, dict]]): + The request object. Request for querying a ``Corpus``. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.QueryCorpusResponse: + Response from QueryCorpus containing a list of relevant + chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.QueryCorpusRequest): + request = retriever_service.QueryCorpusRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_corpus + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_document( + self, + request: Optional[Union[retriever_service.CreateDocumentRequest, dict]] = None, + *, + parent: Optional[str] = None, + document: Optional[retriever.Document] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Creates an empty ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateDocumentRequest( + parent="parent_value", + ) + + # Make the request + response = await client.create_document(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreateDocumentRequest, dict]]): + The request object. Request to create a ``Document``. + parent (:class:`str`): + Required. The name of the ``Corpus`` where this + ``Document`` will be created. Example: + ``corpora/my-corpus-123`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + document (:class:`google.ai.generativelanguage_v1alpha.types.Document`): + Required. The ``Document`` to create. + This corresponds to the ``document`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Document: + A Document is a collection of Chunks. + A Corpus can have a maximum of 10,000 Documents. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, document]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.CreateDocumentRequest): + request = retriever_service.CreateDocumentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if document is not None: + request.document = document + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_document + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_document( + self, + request: Optional[Union[retriever_service.GetDocumentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Gets information about a specific ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetDocumentRequest( + name="name_value", + ) + + # Make the request + response = await client.get_document(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetDocumentRequest, dict]]): + The request object. Request for getting information about a specific + ``Document``. + name (:class:`str`): + Required. The name of the ``Document`` to retrieve. + Example: ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Document: + A Document is a collection of Chunks. + A Corpus can have a maximum of 10,000 Documents. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.GetDocumentRequest): + request = retriever_service.GetDocumentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_document + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_document( + self, + request: Optional[Union[retriever_service.UpdateDocumentRequest, dict]] = None, + *, + document: Optional[retriever.Document] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Updates a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_update_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateDocumentRequest( + ) + + # Make the request + response = await client.update_document(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.UpdateDocumentRequest, dict]]): + The request object. Request to update a ``Document``. + document (:class:`google.ai.generativelanguage_v1alpha.types.Document`): + Required. The ``Document`` to update. + This corresponds to the ``document`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The list of fields to update. Currently, this + only supports updating ``display_name`` and + ``custom_metadata``. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Document: + A Document is a collection of Chunks. + A Corpus can have a maximum of 10,000 Documents. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([document, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.UpdateDocumentRequest): + request = retriever_service.UpdateDocumentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if document is not None: + request.document = document + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_document + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("document.name", request.document.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_document( + self, + request: Optional[Union[retriever_service.DeleteDocumentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteDocumentRequest( + name="name_value", + ) + + # Make the request + await client.delete_document(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeleteDocumentRequest, dict]]): + The request object. Request to delete a ``Document``. + name (:class:`str`): + Required. The resource name of the ``Document`` to + delete. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.DeleteDocumentRequest): + request = retriever_service.DeleteDocumentRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_document + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def list_documents( + self, + request: Optional[Union[retriever_service.ListDocumentsRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListDocumentsAsyncPager: + r"""Lists all ``Document``\ s in a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_documents(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListDocumentsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_documents(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListDocumentsRequest, dict]]): + The request object. Request for listing ``Document``\ s. + parent (:class:`str`): + Required. The name of the ``Corpus`` containing + ``Document``\ s. Example: ``corpora/my-corpus-123`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListDocumentsAsyncPager: + Response from ListDocuments containing a paginated list of Documents. + The Documents are sorted by ascending + document.create_time. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.ListDocumentsRequest): + request = retriever_service.ListDocumentsRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_documents + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListDocumentsAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def query_document( + self, + request: Optional[Union[retriever_service.QueryDocumentRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.QueryDocumentResponse: + r"""Performs semantic search over a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_query_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryDocumentRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = await client.query_document(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.QueryDocumentRequest, dict]]): + The request object. Request for querying a ``Document``. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.QueryDocumentResponse: + Response from QueryDocument containing a list of + relevant chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.QueryDocumentRequest): + request = retriever_service.QueryDocumentRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.query_document + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def create_chunk( + self, + request: Optional[Union[retriever_service.CreateChunkRequest, dict]] = None, + *, + parent: Optional[str] = None, + chunk: Optional[retriever.Chunk] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Creates a ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_create_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.CreateChunkRequest( + parent="parent_value", + chunk=chunk, + ) + + # Make the request + response = await client.create_chunk(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CreateChunkRequest, dict]]): + The request object. Request to create a ``Chunk``. + parent (:class:`str`): + Required. The name of the ``Document`` where this + ``Chunk`` will be created. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + chunk (:class:`google.ai.generativelanguage_v1alpha.types.Chunk`): + Required. The ``Chunk`` to create. + This corresponds to the ``chunk`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Chunk: + A Chunk is a subpart of a Document that is treated as an independent unit + for the purposes of vector representation and + storage. A Corpus can have a maximum of 1 million + Chunks. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, chunk]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.CreateChunkRequest): + request = retriever_service.CreateChunkRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if chunk is not None: + request.chunk = chunk + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.create_chunk + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_create_chunks( + self, + request: Optional[ + Union[retriever_service.BatchCreateChunksRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.BatchCreateChunksResponse: + r"""Batch create ``Chunk``\ s. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_batch_create_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.CreateChunkRequest() + requests.parent = "parent_value" + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchCreateChunksRequest( + requests=requests, + ) + + # Make the request + response = await client.batch_create_chunks(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.BatchCreateChunksRequest, dict]]): + The request object. Request to batch create ``Chunk``\ s. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchCreateChunksResponse: + Response from BatchCreateChunks containing a list of + created Chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.BatchCreateChunksRequest): + request = retriever_service.BatchCreateChunksRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_create_chunks + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_chunk( + self, + request: Optional[Union[retriever_service.GetChunkRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Gets information about a specific ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_get_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetChunkRequest( + name="name_value", + ) + + # Make the request + response = await client.get_chunk(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GetChunkRequest, dict]]): + The request object. Request for getting information about a specific + ``Chunk``. + name (:class:`str`): + Required. The name of the ``Chunk`` to retrieve. + Example: + ``corpora/my-corpus-123/documents/the-doc-abc/chunks/some-chunk`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Chunk: + A Chunk is a subpart of a Document that is treated as an independent unit + for the purposes of vector representation and + storage. A Corpus can have a maximum of 1 million + Chunks. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.GetChunkRequest): + request = retriever_service.GetChunkRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.get_chunk + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def update_chunk( + self, + request: Optional[Union[retriever_service.UpdateChunkRequest, dict]] = None, + *, + chunk: Optional[retriever.Chunk] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Updates a ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_update_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.UpdateChunkRequest( + chunk=chunk, + ) + + # Make the request + response = await client.update_chunk(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.UpdateChunkRequest, dict]]): + The request object. Request to update a ``Chunk``. + chunk (:class:`google.ai.generativelanguage_v1alpha.types.Chunk`): + Required. The ``Chunk`` to update. + This corresponds to the ``chunk`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (:class:`google.protobuf.field_mask_pb2.FieldMask`): + Required. The list of fields to update. Currently, this + only supports updating ``custom_metadata`` and ``data``. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Chunk: + A Chunk is a subpart of a Document that is treated as an independent unit + for the purposes of vector representation and + storage. A Corpus can have a maximum of 1 million + Chunks. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([chunk, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.UpdateChunkRequest): + request = retriever_service.UpdateChunkRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if chunk is not None: + request.chunk = chunk + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.update_chunk + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("chunk.name", request.chunk.name),) + ), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_update_chunks( + self, + request: Optional[ + Union[retriever_service.BatchUpdateChunksRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.BatchUpdateChunksResponse: + r"""Batch update ``Chunk``\ s. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_batch_update_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.UpdateChunkRequest() + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchUpdateChunksRequest( + requests=requests, + ) + + # Make the request + response = await client.batch_update_chunks(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksRequest, dict]]): + The request object. Request to batch update ``Chunk``\ s. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksResponse: + Response from BatchUpdateChunks containing a list of + updated Chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.BatchUpdateChunksRequest): + request = retriever_service.BatchUpdateChunksRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_update_chunks + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def delete_chunk( + self, + request: Optional[Union[retriever_service.DeleteChunkRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_delete_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteChunkRequest( + name="name_value", + ) + + # Make the request + await client.delete_chunk(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.DeleteChunkRequest, dict]]): + The request object. Request to delete a ``Chunk``. + name (:class:`str`): + Required. The resource name of the ``Chunk`` to delete. + Example: + ``corpora/my-corpus-123/documents/the-doc-abc/chunks/some-chunk`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.DeleteChunkRequest): + request = retriever_service.DeleteChunkRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.delete_chunk + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def batch_delete_chunks( + self, + request: Optional[ + Union[retriever_service.BatchDeleteChunksRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Batch delete ``Chunk``\ s. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_batch_delete_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.DeleteChunkRequest() + requests.name = "name_value" + + request = generativelanguage_v1alpha.BatchDeleteChunksRequest( + requests=requests, + ) + + # Make the request + await client.batch_delete_chunks(request=request) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.BatchDeleteChunksRequest, dict]]): + The request object. Request to batch delete ``Chunk``\ s. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.BatchDeleteChunksRequest): + request = retriever_service.BatchDeleteChunksRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_delete_chunks + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + async def list_chunks( + self, + request: Optional[Union[retriever_service.ListChunksRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListChunksAsyncPager: + r"""Lists all ``Chunk``\ s in a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_list_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListChunksRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_chunks(request=request) + + # Handle the response + async for response in page_result: + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.ListChunksRequest, dict]]): + The request object. Request for listing ``Chunk``\ s. + parent (:class:`str`): + Required. The name of the ``Document`` containing + ``Chunk``\ s. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListChunksAsyncPager: + Response from ListChunks containing a paginated list of Chunks. + The Chunks are sorted by ascending chunk.create_time. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.ListChunksRequest): + request = retriever_service.ListChunksRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.list_chunks + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__aiter__` convenience method. + response = pagers.ListChunksAsyncPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "RetrieverServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("RetrieverServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/client.py new file mode 100644 index 000000000000..bfc549edba33 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/client.py @@ -0,0 +1,2875 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.retriever_service import pagers +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +from .transports.base import DEFAULT_CLIENT_INFO, RetrieverServiceTransport +from .transports.grpc import RetrieverServiceGrpcTransport +from .transports.grpc_asyncio import RetrieverServiceGrpcAsyncIOTransport +from .transports.rest import RetrieverServiceRestTransport + + +class RetrieverServiceClientMeta(type): + """Metaclass for the RetrieverService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = ( + OrderedDict() + ) # type: Dict[str, Type[RetrieverServiceTransport]] + _transport_registry["grpc"] = RetrieverServiceGrpcTransport + _transport_registry["grpc_asyncio"] = RetrieverServiceGrpcAsyncIOTransport + _transport_registry["rest"] = RetrieverServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[RetrieverServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class RetrieverServiceClient(metaclass=RetrieverServiceClientMeta): + """An API for semantic search over a corpus of user uploaded + content. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + RetrieverServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + RetrieverServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> RetrieverServiceTransport: + """Returns the transport used by the client instance. + + Returns: + RetrieverServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def chunk_path( + corpus: str, + document: str, + chunk: str, + ) -> str: + """Returns a fully-qualified chunk string.""" + return "corpora/{corpus}/documents/{document}/chunks/{chunk}".format( + corpus=corpus, + document=document, + chunk=chunk, + ) + + @staticmethod + def parse_chunk_path(path: str) -> Dict[str, str]: + """Parses a chunk path into its component segments.""" + m = re.match( + r"^corpora/(?P.+?)/documents/(?P.+?)/chunks/(?P.+?)$", + path, + ) + return m.groupdict() if m else {} + + @staticmethod + def corpus_path( + corpus: str, + ) -> str: + """Returns a fully-qualified corpus string.""" + return "corpora/{corpus}".format( + corpus=corpus, + ) + + @staticmethod + def parse_corpus_path(path: str) -> Dict[str, str]: + """Parses a corpus path into its component segments.""" + m = re.match(r"^corpora/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def document_path( + corpus: str, + document: str, + ) -> str: + """Returns a fully-qualified document string.""" + return "corpora/{corpus}/documents/{document}".format( + corpus=corpus, + document=document, + ) + + @staticmethod + def parse_document_path(path: str) -> Dict[str, str]: + """Parses a document path into its component segments.""" + m = re.match(r"^corpora/(?P.+?)/documents/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = RetrieverServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = RetrieverServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = RetrieverServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = RetrieverServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[ + str, RetrieverServiceTransport, Callable[..., RetrieverServiceTransport] + ] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the retriever service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,RetrieverServiceTransport,Callable[..., RetrieverServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the RetrieverServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = RetrieverServiceClient._read_environment_variables() + self._client_cert_source = RetrieverServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = RetrieverServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, RetrieverServiceTransport) + if transport_provided: + # transport is a RetrieverServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(RetrieverServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = ( + self._api_endpoint + or RetrieverServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[RetrieverServiceTransport], + Callable[..., RetrieverServiceTransport], + ] = ( + RetrieverServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., RetrieverServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.RetrieverServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "credentialsType": None, + }, + ) + + def create_corpus( + self, + request: Optional[Union[retriever_service.CreateCorpusRequest, dict]] = None, + *, + corpus: Optional[retriever.Corpus] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Creates an empty ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCorpusRequest( + ) + + # Make the request + response = client.create_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreateCorpusRequest, dict]): + The request object. Request to create a ``Corpus``. + corpus (google.ai.generativelanguage_v1alpha.types.Corpus): + Required. The ``Corpus`` to create. + This corresponds to the ``corpus`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Corpus: + A Corpus is a collection of Documents. + A project can create up to 5 corpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([corpus]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.CreateCorpusRequest): + request = retriever_service.CreateCorpusRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if corpus is not None: + request.corpus = corpus + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_corpus] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_corpus( + self, + request: Optional[Union[retriever_service.GetCorpusRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Gets information about a specific ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCorpusRequest( + name="name_value", + ) + + # Make the request + response = client.get_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetCorpusRequest, dict]): + The request object. Request for getting information about a specific + ``Corpus``. + name (str): + Required. The name of the ``Corpus``. Example: + ``corpora/my-corpus-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Corpus: + A Corpus is a collection of Documents. + A project can create up to 5 corpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.GetCorpusRequest): + request = retriever_service.GetCorpusRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_corpus] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_corpus( + self, + request: Optional[Union[retriever_service.UpdateCorpusRequest, dict]] = None, + *, + corpus: Optional[retriever.Corpus] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Updates a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_update_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCorpusRequest( + ) + + # Make the request + response = client.update_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.UpdateCorpusRequest, dict]): + The request object. Request to update a ``Corpus``. + corpus (google.ai.generativelanguage_v1alpha.types.Corpus): + Required. The ``Corpus`` to update. + This corresponds to the ``corpus`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, this + only supports updating ``display_name``. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Corpus: + A Corpus is a collection of Documents. + A project can create up to 5 corpora. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([corpus, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.UpdateCorpusRequest): + request = retriever_service.UpdateCorpusRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if corpus is not None: + request.corpus = corpus + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_corpus] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("corpus.name", request.corpus.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_corpus( + self, + request: Optional[Union[retriever_service.DeleteCorpusRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCorpusRequest( + name="name_value", + ) + + # Make the request + client.delete_corpus(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeleteCorpusRequest, dict]): + The request object. Request to delete a ``Corpus``. + name (str): + Required. The resource name of the ``Corpus``. Example: + ``corpora/my-corpus-123`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.DeleteCorpusRequest): + request = retriever_service.DeleteCorpusRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_corpus] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def list_corpora( + self, + request: Optional[Union[retriever_service.ListCorporaRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListCorporaPager: + r"""Lists all ``Corpora`` owned by the user. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_corpora(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCorporaRequest( + ) + + # Make the request + page_result = client.list_corpora(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListCorporaRequest, dict]): + The request object. Request for listing ``Corpora``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListCorporaPager: + Response from ListCorpora containing a paginated list of Corpora. + The results are sorted by ascending + corpus.create_time. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.ListCorporaRequest): + request = retriever_service.ListCorporaRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_corpora] + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListCorporaPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def query_corpus( + self, + request: Optional[Union[retriever_service.QueryCorpusRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.QueryCorpusResponse: + r"""Performs semantic search over a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_query_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryCorpusRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = client.query_corpus(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.QueryCorpusRequest, dict]): + The request object. Request for querying a ``Corpus``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.QueryCorpusResponse: + Response from QueryCorpus containing a list of relevant + chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.QueryCorpusRequest): + request = retriever_service.QueryCorpusRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.query_corpus] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_document( + self, + request: Optional[Union[retriever_service.CreateDocumentRequest, dict]] = None, + *, + parent: Optional[str] = None, + document: Optional[retriever.Document] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Creates an empty ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateDocumentRequest( + parent="parent_value", + ) + + # Make the request + response = client.create_document(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreateDocumentRequest, dict]): + The request object. Request to create a ``Document``. + parent (str): + Required. The name of the ``Corpus`` where this + ``Document`` will be created. Example: + ``corpora/my-corpus-123`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + document (google.ai.generativelanguage_v1alpha.types.Document): + Required. The ``Document`` to create. + This corresponds to the ``document`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Document: + A Document is a collection of Chunks. + A Corpus can have a maximum of 10,000 Documents. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, document]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.CreateDocumentRequest): + request = retriever_service.CreateDocumentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if document is not None: + request.document = document + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_document] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_document( + self, + request: Optional[Union[retriever_service.GetDocumentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Gets information about a specific ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetDocumentRequest( + name="name_value", + ) + + # Make the request + response = client.get_document(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetDocumentRequest, dict]): + The request object. Request for getting information about a specific + ``Document``. + name (str): + Required. The name of the ``Document`` to retrieve. + Example: ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Document: + A Document is a collection of Chunks. + A Corpus can have a maximum of 10,000 Documents. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.GetDocumentRequest): + request = retriever_service.GetDocumentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_document] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_document( + self, + request: Optional[Union[retriever_service.UpdateDocumentRequest, dict]] = None, + *, + document: Optional[retriever.Document] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Updates a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_update_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateDocumentRequest( + ) + + # Make the request + response = client.update_document(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.UpdateDocumentRequest, dict]): + The request object. Request to update a ``Document``. + document (google.ai.generativelanguage_v1alpha.types.Document): + Required. The ``Document`` to update. + This corresponds to the ``document`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, this + only supports updating ``display_name`` and + ``custom_metadata``. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Document: + A Document is a collection of Chunks. + A Corpus can have a maximum of 10,000 Documents. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([document, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.UpdateDocumentRequest): + request = retriever_service.UpdateDocumentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if document is not None: + request.document = document + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_document] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("document.name", request.document.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_document( + self, + request: Optional[Union[retriever_service.DeleteDocumentRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteDocumentRequest( + name="name_value", + ) + + # Make the request + client.delete_document(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeleteDocumentRequest, dict]): + The request object. Request to delete a ``Document``. + name (str): + Required. The resource name of the ``Document`` to + delete. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.DeleteDocumentRequest): + request = retriever_service.DeleteDocumentRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_document] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def list_documents( + self, + request: Optional[Union[retriever_service.ListDocumentsRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListDocumentsPager: + r"""Lists all ``Document``\ s in a ``Corpus``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_documents(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListDocumentsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_documents(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListDocumentsRequest, dict]): + The request object. Request for listing ``Document``\ s. + parent (str): + Required. The name of the ``Corpus`` containing + ``Document``\ s. Example: ``corpora/my-corpus-123`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListDocumentsPager: + Response from ListDocuments containing a paginated list of Documents. + The Documents are sorted by ascending + document.create_time. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.ListDocumentsRequest): + request = retriever_service.ListDocumentsRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_documents] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListDocumentsPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def query_document( + self, + request: Optional[Union[retriever_service.QueryDocumentRequest, dict]] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.QueryDocumentResponse: + r"""Performs semantic search over a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_query_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryDocumentRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = client.query_document(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.QueryDocumentRequest, dict]): + The request object. Request for querying a ``Document``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.QueryDocumentResponse: + Response from QueryDocument containing a list of + relevant chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.QueryDocumentRequest): + request = retriever_service.QueryDocumentRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.query_document] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def create_chunk( + self, + request: Optional[Union[retriever_service.CreateChunkRequest, dict]] = None, + *, + parent: Optional[str] = None, + chunk: Optional[retriever.Chunk] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Creates a ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_create_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.CreateChunkRequest( + parent="parent_value", + chunk=chunk, + ) + + # Make the request + response = client.create_chunk(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CreateChunkRequest, dict]): + The request object. Request to create a ``Chunk``. + parent (str): + Required. The name of the ``Document`` where this + ``Chunk`` will be created. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + chunk (google.ai.generativelanguage_v1alpha.types.Chunk): + Required. The ``Chunk`` to create. + This corresponds to the ``chunk`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Chunk: + A Chunk is a subpart of a Document that is treated as an independent unit + for the purposes of vector representation and + storage. A Corpus can have a maximum of 1 million + Chunks. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent, chunk]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.CreateChunkRequest): + request = retriever_service.CreateChunkRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + if chunk is not None: + request.chunk = chunk + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.create_chunk] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_create_chunks( + self, + request: Optional[ + Union[retriever_service.BatchCreateChunksRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.BatchCreateChunksResponse: + r"""Batch create ``Chunk``\ s. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_batch_create_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.CreateChunkRequest() + requests.parent = "parent_value" + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchCreateChunksRequest( + requests=requests, + ) + + # Make the request + response = client.batch_create_chunks(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.BatchCreateChunksRequest, dict]): + The request object. Request to batch create ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchCreateChunksResponse: + Response from BatchCreateChunks containing a list of + created Chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.BatchCreateChunksRequest): + request = retriever_service.BatchCreateChunksRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_create_chunks] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_chunk( + self, + request: Optional[Union[retriever_service.GetChunkRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Gets information about a specific ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_get_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetChunkRequest( + name="name_value", + ) + + # Make the request + response = client.get_chunk(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GetChunkRequest, dict]): + The request object. Request for getting information about a specific + ``Chunk``. + name (str): + Required. The name of the ``Chunk`` to retrieve. + Example: + ``corpora/my-corpus-123/documents/the-doc-abc/chunks/some-chunk`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Chunk: + A Chunk is a subpart of a Document that is treated as an independent unit + for the purposes of vector representation and + storage. A Corpus can have a maximum of 1 million + Chunks. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.GetChunkRequest): + request = retriever_service.GetChunkRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_chunk] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def update_chunk( + self, + request: Optional[Union[retriever_service.UpdateChunkRequest, dict]] = None, + *, + chunk: Optional[retriever.Chunk] = None, + update_mask: Optional[field_mask_pb2.FieldMask] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Updates a ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_update_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.UpdateChunkRequest( + chunk=chunk, + ) + + # Make the request + response = client.update_chunk(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.UpdateChunkRequest, dict]): + The request object. Request to update a ``Chunk``. + chunk (google.ai.generativelanguage_v1alpha.types.Chunk): + Required. The ``Chunk`` to update. + This corresponds to the ``chunk`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, this + only supports updating ``custom_metadata`` and ``data``. + + This corresponds to the ``update_mask`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.Chunk: + A Chunk is a subpart of a Document that is treated as an independent unit + for the purposes of vector representation and + storage. A Corpus can have a maximum of 1 million + Chunks. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([chunk, update_mask]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.UpdateChunkRequest): + request = retriever_service.UpdateChunkRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if chunk is not None: + request.chunk = chunk + if update_mask is not None: + request.update_mask = update_mask + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.update_chunk] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata( + (("chunk.name", request.chunk.name),) + ), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_update_chunks( + self, + request: Optional[ + Union[retriever_service.BatchUpdateChunksRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.BatchUpdateChunksResponse: + r"""Batch update ``Chunk``\ s. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_batch_update_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.UpdateChunkRequest() + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchUpdateChunksRequest( + requests=requests, + ) + + # Make the request + response = client.batch_update_chunks(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksRequest, dict]): + The request object. Request to batch update ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksResponse: + Response from BatchUpdateChunks containing a list of + updated Chunks. + + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.BatchUpdateChunksRequest): + request = retriever_service.BatchUpdateChunksRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_update_chunks] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def delete_chunk( + self, + request: Optional[Union[retriever_service.DeleteChunkRequest, dict]] = None, + *, + name: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Deletes a ``Chunk``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_delete_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteChunkRequest( + name="name_value", + ) + + # Make the request + client.delete_chunk(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.DeleteChunkRequest, dict]): + The request object. Request to delete a ``Chunk``. + name (str): + Required. The resource name of the ``Chunk`` to delete. + Example: + ``corpora/my-corpus-123/documents/the-doc-abc/chunks/some-chunk`` + + This corresponds to the ``name`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([name]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.DeleteChunkRequest): + request = retriever_service.DeleteChunkRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if name is not None: + request.name = name + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.delete_chunk] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def batch_delete_chunks( + self, + request: Optional[ + Union[retriever_service.BatchDeleteChunksRequest, dict] + ] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> None: + r"""Batch delete ``Chunk``\ s. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_batch_delete_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.DeleteChunkRequest() + requests.name = "name_value" + + request = generativelanguage_v1alpha.BatchDeleteChunksRequest( + requests=requests, + ) + + # Make the request + client.batch_delete_chunks(request=request) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.BatchDeleteChunksRequest, dict]): + The request object. Request to batch delete ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + # Create or coerce a protobuf request object. + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.BatchDeleteChunksRequest): + request = retriever_service.BatchDeleteChunksRequest(request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_delete_chunks] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + def list_chunks( + self, + request: Optional[Union[retriever_service.ListChunksRequest, dict]] = None, + *, + parent: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> pagers.ListChunksPager: + r"""Lists all ``Chunk``\ s in a ``Document``. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_list_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListChunksRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_chunks(request=request) + + # Handle the response + for response in page_result: + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.ListChunksRequest, dict]): + The request object. Request for listing ``Chunk``\ s. + parent (str): + Required. The name of the ``Document`` containing + ``Chunk``\ s. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + + This corresponds to the ``parent`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListChunksPager: + Response from ListChunks containing a paginated list of Chunks. + The Chunks are sorted by ascending chunk.create_time. + + Iterating over this object will yield results and + resolve additional pages automatically. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([parent]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, retriever_service.ListChunksRequest): + request = retriever_service.ListChunksRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if parent is not None: + request.parent = parent + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_chunks] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", request.parent),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # This method is paged; wrap the response in a pager, which provides + # an `__iter__` convenience method. + response = pagers.ListChunksPager( + method=rpc, + request=request, + response=response, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "RetrieverServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("RetrieverServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/pagers.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/pagers.py new file mode 100644 index 000000000000..c2428ed2dca7 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/pagers.py @@ -0,0 +1,509 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + Optional, + Sequence, + Tuple, + Union, +) + +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.api_core import retry_async as retries_async + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] + OptionalAsyncRetry = Union[ + retries_async.AsyncRetry, gapic_v1.method._MethodDefault, None + ] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + OptionalAsyncRetry = Union[retries_async.AsyncRetry, object, None] # type: ignore + +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + + +class ListCorporaPager: + """A pager for iterating through ``list_corpora`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListCorporaResponse` object, and + provides an ``__iter__`` method to iterate through its + ``corpora`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListCorpora`` requests and continue to iterate + through the ``corpora`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListCorporaResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., retriever_service.ListCorporaResponse], + request: retriever_service.ListCorporaRequest, + response: retriever_service.ListCorporaResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListCorporaRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListCorporaResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = retriever_service.ListCorporaRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[retriever_service.ListCorporaResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[retriever.Corpus]: + for page in self.pages: + yield from page.corpora + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListCorporaAsyncPager: + """A pager for iterating through ``list_corpora`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListCorporaResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``corpora`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListCorpora`` requests and continue to iterate + through the ``corpora`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListCorporaResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[retriever_service.ListCorporaResponse]], + request: retriever_service.ListCorporaRequest, + response: retriever_service.ListCorporaResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListCorporaRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListCorporaResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = retriever_service.ListCorporaRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[retriever_service.ListCorporaResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[retriever.Corpus]: + async def async_generator(): + async for page in self.pages: + for response in page.corpora: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDocumentsPager: + """A pager for iterating through ``list_documents`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListDocumentsResponse` object, and + provides an ``__iter__`` method to iterate through its + ``documents`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListDocuments`` requests and continue to iterate + through the ``documents`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListDocumentsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., retriever_service.ListDocumentsResponse], + request: retriever_service.ListDocumentsRequest, + response: retriever_service.ListDocumentsResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListDocumentsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListDocumentsResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = retriever_service.ListDocumentsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[retriever_service.ListDocumentsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[retriever.Document]: + for page in self.pages: + yield from page.documents + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListDocumentsAsyncPager: + """A pager for iterating through ``list_documents`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListDocumentsResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``documents`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListDocuments`` requests and continue to iterate + through the ``documents`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListDocumentsResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[retriever_service.ListDocumentsResponse]], + request: retriever_service.ListDocumentsRequest, + response: retriever_service.ListDocumentsResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListDocumentsRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListDocumentsResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = retriever_service.ListDocumentsRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[retriever_service.ListDocumentsResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[retriever.Document]: + async def async_generator(): + async for page in self.pages: + for response in page.documents: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListChunksPager: + """A pager for iterating through ``list_chunks`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListChunksResponse` object, and + provides an ``__iter__`` method to iterate through its + ``chunks`` field. + + If there are more pages, the ``__iter__`` method will make additional + ``ListChunks`` requests and continue to iterate + through the ``chunks`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListChunksResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., retriever_service.ListChunksResponse], + request: retriever_service.ListChunksRequest, + response: retriever_service.ListChunksResponse, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiate the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListChunksRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListChunksResponse): + The initial response object. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = retriever_service.ListChunksRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + def pages(self) -> Iterator[retriever_service.ListChunksResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __iter__(self) -> Iterator[retriever.Chunk]: + for page in self.pages: + yield from page.chunks + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) + + +class ListChunksAsyncPager: + """A pager for iterating through ``list_chunks`` requests. + + This class thinly wraps an initial + :class:`google.ai.generativelanguage_v1alpha.types.ListChunksResponse` object, and + provides an ``__aiter__`` method to iterate through its + ``chunks`` field. + + If there are more pages, the ``__aiter__`` method will make additional + ``ListChunks`` requests and continue to iterate + through the ``chunks`` field on the + corresponding responses. + + All the usual :class:`google.ai.generativelanguage_v1alpha.types.ListChunksResponse` + attributes are available on the pager. If multiple requests are made, only + the most recent response is retained, and thus used for attribute lookup. + """ + + def __init__( + self, + method: Callable[..., Awaitable[retriever_service.ListChunksResponse]], + request: retriever_service.ListChunksRequest, + response: retriever_service.ListChunksResponse, + *, + retry: OptionalAsyncRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = () + ): + """Instantiates the pager. + + Args: + method (Callable): The method that was originally called, and + which instantiated this pager. + request (google.ai.generativelanguage_v1alpha.types.ListChunksRequest): + The initial request object. + response (google.ai.generativelanguage_v1alpha.types.ListChunksResponse): + The initial response object. + retry (google.api_core.retry.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + self._method = method + self._request = retriever_service.ListChunksRequest(request) + self._response = response + self._retry = retry + self._timeout = timeout + self._metadata = metadata + + def __getattr__(self, name: str) -> Any: + return getattr(self._response, name) + + @property + async def pages(self) -> AsyncIterator[retriever_service.ListChunksResponse]: + yield self._response + while self._response.next_page_token: + self._request.page_token = self._response.next_page_token + self._response = await self._method( + self._request, + retry=self._retry, + timeout=self._timeout, + metadata=self._metadata, + ) + yield self._response + + def __aiter__(self) -> AsyncIterator[retriever.Chunk]: + async def async_generator(): + async for page in self.pages: + for response in page.chunks: + yield response + + return async_generator() + + def __repr__(self) -> str: + return "{0}<{1!r}>".format(self.__class__.__name__, self._response) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/README.rst new file mode 100644 index 000000000000..5f241959ab33 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`RetrieverServiceTransport` is the ABC for all transports. +- public child `RetrieverServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `RetrieverServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseRetrieverServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `RetrieverServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/__init__.py new file mode 100644 index 000000000000..81046f334905 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import RetrieverServiceTransport +from .grpc import RetrieverServiceGrpcTransport +from .grpc_asyncio import RetrieverServiceGrpcAsyncIOTransport +from .rest import RetrieverServiceRestInterceptor, RetrieverServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[RetrieverServiceTransport]] +_transport_registry["grpc"] = RetrieverServiceGrpcTransport +_transport_registry["grpc_asyncio"] = RetrieverServiceGrpcAsyncIOTransport +_transport_registry["rest"] = RetrieverServiceRestTransport + +__all__ = ( + "RetrieverServiceTransport", + "RetrieverServiceGrpcTransport", + "RetrieverServiceGrpcAsyncIOTransport", + "RetrieverServiceRestTransport", + "RetrieverServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/base.py new file mode 100644 index 000000000000..92bec7927842 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/base.py @@ -0,0 +1,481 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore +from google.protobuf import empty_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class RetrieverServiceTransport(abc.ABC): + """Abstract transport class for RetrieverService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.create_corpus: gapic_v1.method.wrap_method( + self.create_corpus, + default_timeout=None, + client_info=client_info, + ), + self.get_corpus: gapic_v1.method.wrap_method( + self.get_corpus, + default_timeout=None, + client_info=client_info, + ), + self.update_corpus: gapic_v1.method.wrap_method( + self.update_corpus, + default_timeout=None, + client_info=client_info, + ), + self.delete_corpus: gapic_v1.method.wrap_method( + self.delete_corpus, + default_timeout=None, + client_info=client_info, + ), + self.list_corpora: gapic_v1.method.wrap_method( + self.list_corpora, + default_timeout=None, + client_info=client_info, + ), + self.query_corpus: gapic_v1.method.wrap_method( + self.query_corpus, + default_timeout=None, + client_info=client_info, + ), + self.create_document: gapic_v1.method.wrap_method( + self.create_document, + default_timeout=None, + client_info=client_info, + ), + self.get_document: gapic_v1.method.wrap_method( + self.get_document, + default_timeout=None, + client_info=client_info, + ), + self.update_document: gapic_v1.method.wrap_method( + self.update_document, + default_timeout=None, + client_info=client_info, + ), + self.delete_document: gapic_v1.method.wrap_method( + self.delete_document, + default_timeout=None, + client_info=client_info, + ), + self.list_documents: gapic_v1.method.wrap_method( + self.list_documents, + default_timeout=None, + client_info=client_info, + ), + self.query_document: gapic_v1.method.wrap_method( + self.query_document, + default_timeout=None, + client_info=client_info, + ), + self.create_chunk: gapic_v1.method.wrap_method( + self.create_chunk, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_chunks: gapic_v1.method.wrap_method( + self.batch_create_chunks, + default_timeout=None, + client_info=client_info, + ), + self.get_chunk: gapic_v1.method.wrap_method( + self.get_chunk, + default_timeout=None, + client_info=client_info, + ), + self.update_chunk: gapic_v1.method.wrap_method( + self.update_chunk, + default_timeout=None, + client_info=client_info, + ), + self.batch_update_chunks: gapic_v1.method.wrap_method( + self.batch_update_chunks, + default_timeout=None, + client_info=client_info, + ), + self.delete_chunk: gapic_v1.method.wrap_method( + self.delete_chunk, + default_timeout=None, + client_info=client_info, + ), + self.batch_delete_chunks: gapic_v1.method.wrap_method( + self.batch_delete_chunks, + default_timeout=None, + client_info=client_info, + ), + self.list_chunks: gapic_v1.method.wrap_method( + self.list_chunks, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def create_corpus( + self, + ) -> Callable[ + [retriever_service.CreateCorpusRequest], + Union[retriever.Corpus, Awaitable[retriever.Corpus]], + ]: + raise NotImplementedError() + + @property + def get_corpus( + self, + ) -> Callable[ + [retriever_service.GetCorpusRequest], + Union[retriever.Corpus, Awaitable[retriever.Corpus]], + ]: + raise NotImplementedError() + + @property + def update_corpus( + self, + ) -> Callable[ + [retriever_service.UpdateCorpusRequest], + Union[retriever.Corpus, Awaitable[retriever.Corpus]], + ]: + raise NotImplementedError() + + @property + def delete_corpus( + self, + ) -> Callable[ + [retriever_service.DeleteCorpusRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def list_corpora( + self, + ) -> Callable[ + [retriever_service.ListCorporaRequest], + Union[ + retriever_service.ListCorporaResponse, + Awaitable[retriever_service.ListCorporaResponse], + ], + ]: + raise NotImplementedError() + + @property + def query_corpus( + self, + ) -> Callable[ + [retriever_service.QueryCorpusRequest], + Union[ + retriever_service.QueryCorpusResponse, + Awaitable[retriever_service.QueryCorpusResponse], + ], + ]: + raise NotImplementedError() + + @property + def create_document( + self, + ) -> Callable[ + [retriever_service.CreateDocumentRequest], + Union[retriever.Document, Awaitable[retriever.Document]], + ]: + raise NotImplementedError() + + @property + def get_document( + self, + ) -> Callable[ + [retriever_service.GetDocumentRequest], + Union[retriever.Document, Awaitable[retriever.Document]], + ]: + raise NotImplementedError() + + @property + def update_document( + self, + ) -> Callable[ + [retriever_service.UpdateDocumentRequest], + Union[retriever.Document, Awaitable[retriever.Document]], + ]: + raise NotImplementedError() + + @property + def delete_document( + self, + ) -> Callable[ + [retriever_service.DeleteDocumentRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def list_documents( + self, + ) -> Callable[ + [retriever_service.ListDocumentsRequest], + Union[ + retriever_service.ListDocumentsResponse, + Awaitable[retriever_service.ListDocumentsResponse], + ], + ]: + raise NotImplementedError() + + @property + def query_document( + self, + ) -> Callable[ + [retriever_service.QueryDocumentRequest], + Union[ + retriever_service.QueryDocumentResponse, + Awaitable[retriever_service.QueryDocumentResponse], + ], + ]: + raise NotImplementedError() + + @property + def create_chunk( + self, + ) -> Callable[ + [retriever_service.CreateChunkRequest], + Union[retriever.Chunk, Awaitable[retriever.Chunk]], + ]: + raise NotImplementedError() + + @property + def batch_create_chunks( + self, + ) -> Callable[ + [retriever_service.BatchCreateChunksRequest], + Union[ + retriever_service.BatchCreateChunksResponse, + Awaitable[retriever_service.BatchCreateChunksResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_chunk( + self, + ) -> Callable[ + [retriever_service.GetChunkRequest], + Union[retriever.Chunk, Awaitable[retriever.Chunk]], + ]: + raise NotImplementedError() + + @property + def update_chunk( + self, + ) -> Callable[ + [retriever_service.UpdateChunkRequest], + Union[retriever.Chunk, Awaitable[retriever.Chunk]], + ]: + raise NotImplementedError() + + @property + def batch_update_chunks( + self, + ) -> Callable[ + [retriever_service.BatchUpdateChunksRequest], + Union[ + retriever_service.BatchUpdateChunksResponse, + Awaitable[retriever_service.BatchUpdateChunksResponse], + ], + ]: + raise NotImplementedError() + + @property + def delete_chunk( + self, + ) -> Callable[ + [retriever_service.DeleteChunkRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def batch_delete_chunks( + self, + ) -> Callable[ + [retriever_service.BatchDeleteChunksRequest], + Union[empty_pb2.Empty, Awaitable[empty_pb2.Empty]], + ]: + raise NotImplementedError() + + @property + def list_chunks( + self, + ) -> Callable[ + [retriever_service.ListChunksRequest], + Union[ + retriever_service.ListChunksResponse, + Awaitable[retriever_service.ListChunksResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("RetrieverServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc.py new file mode 100644 index 000000000000..5467ec9eb439 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc.py @@ -0,0 +1,908 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +from .base import DEFAULT_CLIENT_INFO, RetrieverServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class RetrieverServiceGrpcTransport(RetrieverServiceTransport): + """gRPC backend transport for RetrieverService. + + An API for semantic search over a corpus of user uploaded + content. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def create_corpus( + self, + ) -> Callable[[retriever_service.CreateCorpusRequest], retriever.Corpus]: + r"""Return a callable for the create corpus method over gRPC. + + Creates an empty ``Corpus``. + + Returns: + Callable[[~.CreateCorpusRequest], + ~.Corpus]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_corpus" not in self._stubs: + self._stubs["create_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/CreateCorpus", + request_serializer=retriever_service.CreateCorpusRequest.serialize, + response_deserializer=retriever.Corpus.deserialize, + ) + return self._stubs["create_corpus"] + + @property + def get_corpus( + self, + ) -> Callable[[retriever_service.GetCorpusRequest], retriever.Corpus]: + r"""Return a callable for the get corpus method over gRPC. + + Gets information about a specific ``Corpus``. + + Returns: + Callable[[~.GetCorpusRequest], + ~.Corpus]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_corpus" not in self._stubs: + self._stubs["get_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/GetCorpus", + request_serializer=retriever_service.GetCorpusRequest.serialize, + response_deserializer=retriever.Corpus.deserialize, + ) + return self._stubs["get_corpus"] + + @property + def update_corpus( + self, + ) -> Callable[[retriever_service.UpdateCorpusRequest], retriever.Corpus]: + r"""Return a callable for the update corpus method over gRPC. + + Updates a ``Corpus``. + + Returns: + Callable[[~.UpdateCorpusRequest], + ~.Corpus]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_corpus" not in self._stubs: + self._stubs["update_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/UpdateCorpus", + request_serializer=retriever_service.UpdateCorpusRequest.serialize, + response_deserializer=retriever.Corpus.deserialize, + ) + return self._stubs["update_corpus"] + + @property + def delete_corpus( + self, + ) -> Callable[[retriever_service.DeleteCorpusRequest], empty_pb2.Empty]: + r"""Return a callable for the delete corpus method over gRPC. + + Deletes a ``Corpus``. + + Returns: + Callable[[~.DeleteCorpusRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_corpus" not in self._stubs: + self._stubs["delete_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/DeleteCorpus", + request_serializer=retriever_service.DeleteCorpusRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_corpus"] + + @property + def list_corpora( + self, + ) -> Callable[ + [retriever_service.ListCorporaRequest], retriever_service.ListCorporaResponse + ]: + r"""Return a callable for the list corpora method over gRPC. + + Lists all ``Corpora`` owned by the user. + + Returns: + Callable[[~.ListCorporaRequest], + ~.ListCorporaResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_corpora" not in self._stubs: + self._stubs["list_corpora"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/ListCorpora", + request_serializer=retriever_service.ListCorporaRequest.serialize, + response_deserializer=retriever_service.ListCorporaResponse.deserialize, + ) + return self._stubs["list_corpora"] + + @property + def query_corpus( + self, + ) -> Callable[ + [retriever_service.QueryCorpusRequest], retriever_service.QueryCorpusResponse + ]: + r"""Return a callable for the query corpus method over gRPC. + + Performs semantic search over a ``Corpus``. + + Returns: + Callable[[~.QueryCorpusRequest], + ~.QueryCorpusResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "query_corpus" not in self._stubs: + self._stubs["query_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/QueryCorpus", + request_serializer=retriever_service.QueryCorpusRequest.serialize, + response_deserializer=retriever_service.QueryCorpusResponse.deserialize, + ) + return self._stubs["query_corpus"] + + @property + def create_document( + self, + ) -> Callable[[retriever_service.CreateDocumentRequest], retriever.Document]: + r"""Return a callable for the create document method over gRPC. + + Creates an empty ``Document``. + + Returns: + Callable[[~.CreateDocumentRequest], + ~.Document]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_document" not in self._stubs: + self._stubs["create_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/CreateDocument", + request_serializer=retriever_service.CreateDocumentRequest.serialize, + response_deserializer=retriever.Document.deserialize, + ) + return self._stubs["create_document"] + + @property + def get_document( + self, + ) -> Callable[[retriever_service.GetDocumentRequest], retriever.Document]: + r"""Return a callable for the get document method over gRPC. + + Gets information about a specific ``Document``. + + Returns: + Callable[[~.GetDocumentRequest], + ~.Document]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_document" not in self._stubs: + self._stubs["get_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/GetDocument", + request_serializer=retriever_service.GetDocumentRequest.serialize, + response_deserializer=retriever.Document.deserialize, + ) + return self._stubs["get_document"] + + @property + def update_document( + self, + ) -> Callable[[retriever_service.UpdateDocumentRequest], retriever.Document]: + r"""Return a callable for the update document method over gRPC. + + Updates a ``Document``. + + Returns: + Callable[[~.UpdateDocumentRequest], + ~.Document]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_document" not in self._stubs: + self._stubs["update_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/UpdateDocument", + request_serializer=retriever_service.UpdateDocumentRequest.serialize, + response_deserializer=retriever.Document.deserialize, + ) + return self._stubs["update_document"] + + @property + def delete_document( + self, + ) -> Callable[[retriever_service.DeleteDocumentRequest], empty_pb2.Empty]: + r"""Return a callable for the delete document method over gRPC. + + Deletes a ``Document``. + + Returns: + Callable[[~.DeleteDocumentRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_document" not in self._stubs: + self._stubs["delete_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/DeleteDocument", + request_serializer=retriever_service.DeleteDocumentRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_document"] + + @property + def list_documents( + self, + ) -> Callable[ + [retriever_service.ListDocumentsRequest], + retriever_service.ListDocumentsResponse, + ]: + r"""Return a callable for the list documents method over gRPC. + + Lists all ``Document``\ s in a ``Corpus``. + + Returns: + Callable[[~.ListDocumentsRequest], + ~.ListDocumentsResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_documents" not in self._stubs: + self._stubs["list_documents"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/ListDocuments", + request_serializer=retriever_service.ListDocumentsRequest.serialize, + response_deserializer=retriever_service.ListDocumentsResponse.deserialize, + ) + return self._stubs["list_documents"] + + @property + def query_document( + self, + ) -> Callable[ + [retriever_service.QueryDocumentRequest], + retriever_service.QueryDocumentResponse, + ]: + r"""Return a callable for the query document method over gRPC. + + Performs semantic search over a ``Document``. + + Returns: + Callable[[~.QueryDocumentRequest], + ~.QueryDocumentResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "query_document" not in self._stubs: + self._stubs["query_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/QueryDocument", + request_serializer=retriever_service.QueryDocumentRequest.serialize, + response_deserializer=retriever_service.QueryDocumentResponse.deserialize, + ) + return self._stubs["query_document"] + + @property + def create_chunk( + self, + ) -> Callable[[retriever_service.CreateChunkRequest], retriever.Chunk]: + r"""Return a callable for the create chunk method over gRPC. + + Creates a ``Chunk``. + + Returns: + Callable[[~.CreateChunkRequest], + ~.Chunk]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_chunk" not in self._stubs: + self._stubs["create_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/CreateChunk", + request_serializer=retriever_service.CreateChunkRequest.serialize, + response_deserializer=retriever.Chunk.deserialize, + ) + return self._stubs["create_chunk"] + + @property + def batch_create_chunks( + self, + ) -> Callable[ + [retriever_service.BatchCreateChunksRequest], + retriever_service.BatchCreateChunksResponse, + ]: + r"""Return a callable for the batch create chunks method over gRPC. + + Batch create ``Chunk``\ s. + + Returns: + Callable[[~.BatchCreateChunksRequest], + ~.BatchCreateChunksResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_create_chunks" not in self._stubs: + self._stubs["batch_create_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/BatchCreateChunks", + request_serializer=retriever_service.BatchCreateChunksRequest.serialize, + response_deserializer=retriever_service.BatchCreateChunksResponse.deserialize, + ) + return self._stubs["batch_create_chunks"] + + @property + def get_chunk( + self, + ) -> Callable[[retriever_service.GetChunkRequest], retriever.Chunk]: + r"""Return a callable for the get chunk method over gRPC. + + Gets information about a specific ``Chunk``. + + Returns: + Callable[[~.GetChunkRequest], + ~.Chunk]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_chunk" not in self._stubs: + self._stubs["get_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/GetChunk", + request_serializer=retriever_service.GetChunkRequest.serialize, + response_deserializer=retriever.Chunk.deserialize, + ) + return self._stubs["get_chunk"] + + @property + def update_chunk( + self, + ) -> Callable[[retriever_service.UpdateChunkRequest], retriever.Chunk]: + r"""Return a callable for the update chunk method over gRPC. + + Updates a ``Chunk``. + + Returns: + Callable[[~.UpdateChunkRequest], + ~.Chunk]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_chunk" not in self._stubs: + self._stubs["update_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/UpdateChunk", + request_serializer=retriever_service.UpdateChunkRequest.serialize, + response_deserializer=retriever.Chunk.deserialize, + ) + return self._stubs["update_chunk"] + + @property + def batch_update_chunks( + self, + ) -> Callable[ + [retriever_service.BatchUpdateChunksRequest], + retriever_service.BatchUpdateChunksResponse, + ]: + r"""Return a callable for the batch update chunks method over gRPC. + + Batch update ``Chunk``\ s. + + Returns: + Callable[[~.BatchUpdateChunksRequest], + ~.BatchUpdateChunksResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_update_chunks" not in self._stubs: + self._stubs["batch_update_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/BatchUpdateChunks", + request_serializer=retriever_service.BatchUpdateChunksRequest.serialize, + response_deserializer=retriever_service.BatchUpdateChunksResponse.deserialize, + ) + return self._stubs["batch_update_chunks"] + + @property + def delete_chunk( + self, + ) -> Callable[[retriever_service.DeleteChunkRequest], empty_pb2.Empty]: + r"""Return a callable for the delete chunk method over gRPC. + + Deletes a ``Chunk``. + + Returns: + Callable[[~.DeleteChunkRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_chunk" not in self._stubs: + self._stubs["delete_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/DeleteChunk", + request_serializer=retriever_service.DeleteChunkRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_chunk"] + + @property + def batch_delete_chunks( + self, + ) -> Callable[[retriever_service.BatchDeleteChunksRequest], empty_pb2.Empty]: + r"""Return a callable for the batch delete chunks method over gRPC. + + Batch delete ``Chunk``\ s. + + Returns: + Callable[[~.BatchDeleteChunksRequest], + ~.Empty]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_delete_chunks" not in self._stubs: + self._stubs["batch_delete_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/BatchDeleteChunks", + request_serializer=retriever_service.BatchDeleteChunksRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["batch_delete_chunks"] + + @property + def list_chunks( + self, + ) -> Callable[ + [retriever_service.ListChunksRequest], retriever_service.ListChunksResponse + ]: + r"""Return a callable for the list chunks method over gRPC. + + Lists all ``Chunk``\ s in a ``Document``. + + Returns: + Callable[[~.ListChunksRequest], + ~.ListChunksResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_chunks" not in self._stubs: + self._stubs["list_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/ListChunks", + request_serializer=retriever_service.ListChunksRequest.serialize, + response_deserializer=retriever_service.ListChunksResponse.deserialize, + ) + return self._stubs["list_chunks"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("RetrieverServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..19135680af8d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/grpc_asyncio.py @@ -0,0 +1,1048 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +from .base import DEFAULT_CLIENT_INFO, RetrieverServiceTransport +from .grpc import RetrieverServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class RetrieverServiceGrpcAsyncIOTransport(RetrieverServiceTransport): + """gRPC AsyncIO backend transport for RetrieverService. + + An API for semantic search over a corpus of user uploaded + content. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def create_corpus( + self, + ) -> Callable[[retriever_service.CreateCorpusRequest], Awaitable[retriever.Corpus]]: + r"""Return a callable for the create corpus method over gRPC. + + Creates an empty ``Corpus``. + + Returns: + Callable[[~.CreateCorpusRequest], + Awaitable[~.Corpus]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_corpus" not in self._stubs: + self._stubs["create_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/CreateCorpus", + request_serializer=retriever_service.CreateCorpusRequest.serialize, + response_deserializer=retriever.Corpus.deserialize, + ) + return self._stubs["create_corpus"] + + @property + def get_corpus( + self, + ) -> Callable[[retriever_service.GetCorpusRequest], Awaitable[retriever.Corpus]]: + r"""Return a callable for the get corpus method over gRPC. + + Gets information about a specific ``Corpus``. + + Returns: + Callable[[~.GetCorpusRequest], + Awaitable[~.Corpus]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_corpus" not in self._stubs: + self._stubs["get_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/GetCorpus", + request_serializer=retriever_service.GetCorpusRequest.serialize, + response_deserializer=retriever.Corpus.deserialize, + ) + return self._stubs["get_corpus"] + + @property + def update_corpus( + self, + ) -> Callable[[retriever_service.UpdateCorpusRequest], Awaitable[retriever.Corpus]]: + r"""Return a callable for the update corpus method over gRPC. + + Updates a ``Corpus``. + + Returns: + Callable[[~.UpdateCorpusRequest], + Awaitable[~.Corpus]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_corpus" not in self._stubs: + self._stubs["update_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/UpdateCorpus", + request_serializer=retriever_service.UpdateCorpusRequest.serialize, + response_deserializer=retriever.Corpus.deserialize, + ) + return self._stubs["update_corpus"] + + @property + def delete_corpus( + self, + ) -> Callable[[retriever_service.DeleteCorpusRequest], Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the delete corpus method over gRPC. + + Deletes a ``Corpus``. + + Returns: + Callable[[~.DeleteCorpusRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_corpus" not in self._stubs: + self._stubs["delete_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/DeleteCorpus", + request_serializer=retriever_service.DeleteCorpusRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_corpus"] + + @property + def list_corpora( + self, + ) -> Callable[ + [retriever_service.ListCorporaRequest], + Awaitable[retriever_service.ListCorporaResponse], + ]: + r"""Return a callable for the list corpora method over gRPC. + + Lists all ``Corpora`` owned by the user. + + Returns: + Callable[[~.ListCorporaRequest], + Awaitable[~.ListCorporaResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_corpora" not in self._stubs: + self._stubs["list_corpora"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/ListCorpora", + request_serializer=retriever_service.ListCorporaRequest.serialize, + response_deserializer=retriever_service.ListCorporaResponse.deserialize, + ) + return self._stubs["list_corpora"] + + @property + def query_corpus( + self, + ) -> Callable[ + [retriever_service.QueryCorpusRequest], + Awaitable[retriever_service.QueryCorpusResponse], + ]: + r"""Return a callable for the query corpus method over gRPC. + + Performs semantic search over a ``Corpus``. + + Returns: + Callable[[~.QueryCorpusRequest], + Awaitable[~.QueryCorpusResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "query_corpus" not in self._stubs: + self._stubs["query_corpus"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/QueryCorpus", + request_serializer=retriever_service.QueryCorpusRequest.serialize, + response_deserializer=retriever_service.QueryCorpusResponse.deserialize, + ) + return self._stubs["query_corpus"] + + @property + def create_document( + self, + ) -> Callable[ + [retriever_service.CreateDocumentRequest], Awaitable[retriever.Document] + ]: + r"""Return a callable for the create document method over gRPC. + + Creates an empty ``Document``. + + Returns: + Callable[[~.CreateDocumentRequest], + Awaitable[~.Document]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_document" not in self._stubs: + self._stubs["create_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/CreateDocument", + request_serializer=retriever_service.CreateDocumentRequest.serialize, + response_deserializer=retriever.Document.deserialize, + ) + return self._stubs["create_document"] + + @property + def get_document( + self, + ) -> Callable[ + [retriever_service.GetDocumentRequest], Awaitable[retriever.Document] + ]: + r"""Return a callable for the get document method over gRPC. + + Gets information about a specific ``Document``. + + Returns: + Callable[[~.GetDocumentRequest], + Awaitable[~.Document]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_document" not in self._stubs: + self._stubs["get_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/GetDocument", + request_serializer=retriever_service.GetDocumentRequest.serialize, + response_deserializer=retriever.Document.deserialize, + ) + return self._stubs["get_document"] + + @property + def update_document( + self, + ) -> Callable[ + [retriever_service.UpdateDocumentRequest], Awaitable[retriever.Document] + ]: + r"""Return a callable for the update document method over gRPC. + + Updates a ``Document``. + + Returns: + Callable[[~.UpdateDocumentRequest], + Awaitable[~.Document]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_document" not in self._stubs: + self._stubs["update_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/UpdateDocument", + request_serializer=retriever_service.UpdateDocumentRequest.serialize, + response_deserializer=retriever.Document.deserialize, + ) + return self._stubs["update_document"] + + @property + def delete_document( + self, + ) -> Callable[ + [retriever_service.DeleteDocumentRequest], Awaitable[empty_pb2.Empty] + ]: + r"""Return a callable for the delete document method over gRPC. + + Deletes a ``Document``. + + Returns: + Callable[[~.DeleteDocumentRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_document" not in self._stubs: + self._stubs["delete_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/DeleteDocument", + request_serializer=retriever_service.DeleteDocumentRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_document"] + + @property + def list_documents( + self, + ) -> Callable[ + [retriever_service.ListDocumentsRequest], + Awaitable[retriever_service.ListDocumentsResponse], + ]: + r"""Return a callable for the list documents method over gRPC. + + Lists all ``Document``\ s in a ``Corpus``. + + Returns: + Callable[[~.ListDocumentsRequest], + Awaitable[~.ListDocumentsResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_documents" not in self._stubs: + self._stubs["list_documents"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/ListDocuments", + request_serializer=retriever_service.ListDocumentsRequest.serialize, + response_deserializer=retriever_service.ListDocumentsResponse.deserialize, + ) + return self._stubs["list_documents"] + + @property + def query_document( + self, + ) -> Callable[ + [retriever_service.QueryDocumentRequest], + Awaitable[retriever_service.QueryDocumentResponse], + ]: + r"""Return a callable for the query document method over gRPC. + + Performs semantic search over a ``Document``. + + Returns: + Callable[[~.QueryDocumentRequest], + Awaitable[~.QueryDocumentResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "query_document" not in self._stubs: + self._stubs["query_document"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/QueryDocument", + request_serializer=retriever_service.QueryDocumentRequest.serialize, + response_deserializer=retriever_service.QueryDocumentResponse.deserialize, + ) + return self._stubs["query_document"] + + @property + def create_chunk( + self, + ) -> Callable[[retriever_service.CreateChunkRequest], Awaitable[retriever.Chunk]]: + r"""Return a callable for the create chunk method over gRPC. + + Creates a ``Chunk``. + + Returns: + Callable[[~.CreateChunkRequest], + Awaitable[~.Chunk]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "create_chunk" not in self._stubs: + self._stubs["create_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/CreateChunk", + request_serializer=retriever_service.CreateChunkRequest.serialize, + response_deserializer=retriever.Chunk.deserialize, + ) + return self._stubs["create_chunk"] + + @property + def batch_create_chunks( + self, + ) -> Callable[ + [retriever_service.BatchCreateChunksRequest], + Awaitable[retriever_service.BatchCreateChunksResponse], + ]: + r"""Return a callable for the batch create chunks method over gRPC. + + Batch create ``Chunk``\ s. + + Returns: + Callable[[~.BatchCreateChunksRequest], + Awaitable[~.BatchCreateChunksResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_create_chunks" not in self._stubs: + self._stubs["batch_create_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/BatchCreateChunks", + request_serializer=retriever_service.BatchCreateChunksRequest.serialize, + response_deserializer=retriever_service.BatchCreateChunksResponse.deserialize, + ) + return self._stubs["batch_create_chunks"] + + @property + def get_chunk( + self, + ) -> Callable[[retriever_service.GetChunkRequest], Awaitable[retriever.Chunk]]: + r"""Return a callable for the get chunk method over gRPC. + + Gets information about a specific ``Chunk``. + + Returns: + Callable[[~.GetChunkRequest], + Awaitable[~.Chunk]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_chunk" not in self._stubs: + self._stubs["get_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/GetChunk", + request_serializer=retriever_service.GetChunkRequest.serialize, + response_deserializer=retriever.Chunk.deserialize, + ) + return self._stubs["get_chunk"] + + @property + def update_chunk( + self, + ) -> Callable[[retriever_service.UpdateChunkRequest], Awaitable[retriever.Chunk]]: + r"""Return a callable for the update chunk method over gRPC. + + Updates a ``Chunk``. + + Returns: + Callable[[~.UpdateChunkRequest], + Awaitable[~.Chunk]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "update_chunk" not in self._stubs: + self._stubs["update_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/UpdateChunk", + request_serializer=retriever_service.UpdateChunkRequest.serialize, + response_deserializer=retriever.Chunk.deserialize, + ) + return self._stubs["update_chunk"] + + @property + def batch_update_chunks( + self, + ) -> Callable[ + [retriever_service.BatchUpdateChunksRequest], + Awaitable[retriever_service.BatchUpdateChunksResponse], + ]: + r"""Return a callable for the batch update chunks method over gRPC. + + Batch update ``Chunk``\ s. + + Returns: + Callable[[~.BatchUpdateChunksRequest], + Awaitable[~.BatchUpdateChunksResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_update_chunks" not in self._stubs: + self._stubs["batch_update_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/BatchUpdateChunks", + request_serializer=retriever_service.BatchUpdateChunksRequest.serialize, + response_deserializer=retriever_service.BatchUpdateChunksResponse.deserialize, + ) + return self._stubs["batch_update_chunks"] + + @property + def delete_chunk( + self, + ) -> Callable[[retriever_service.DeleteChunkRequest], Awaitable[empty_pb2.Empty]]: + r"""Return a callable for the delete chunk method over gRPC. + + Deletes a ``Chunk``. + + Returns: + Callable[[~.DeleteChunkRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "delete_chunk" not in self._stubs: + self._stubs["delete_chunk"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/DeleteChunk", + request_serializer=retriever_service.DeleteChunkRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["delete_chunk"] + + @property + def batch_delete_chunks( + self, + ) -> Callable[ + [retriever_service.BatchDeleteChunksRequest], Awaitable[empty_pb2.Empty] + ]: + r"""Return a callable for the batch delete chunks method over gRPC. + + Batch delete ``Chunk``\ s. + + Returns: + Callable[[~.BatchDeleteChunksRequest], + Awaitable[~.Empty]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_delete_chunks" not in self._stubs: + self._stubs["batch_delete_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/BatchDeleteChunks", + request_serializer=retriever_service.BatchDeleteChunksRequest.serialize, + response_deserializer=empty_pb2.Empty.FromString, + ) + return self._stubs["batch_delete_chunks"] + + @property + def list_chunks( + self, + ) -> Callable[ + [retriever_service.ListChunksRequest], + Awaitable[retriever_service.ListChunksResponse], + ]: + r"""Return a callable for the list chunks method over gRPC. + + Lists all ``Chunk``\ s in a ``Document``. + + Returns: + Callable[[~.ListChunksRequest], + Awaitable[~.ListChunksResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_chunks" not in self._stubs: + self._stubs["list_chunks"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.RetrieverService/ListChunks", + request_serializer=retriever_service.ListChunksRequest.serialize, + response_deserializer=retriever_service.ListChunksResponse.deserialize, + ) + return self._stubs["list_chunks"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.create_corpus: self._wrap_method( + self.create_corpus, + default_timeout=None, + client_info=client_info, + ), + self.get_corpus: self._wrap_method( + self.get_corpus, + default_timeout=None, + client_info=client_info, + ), + self.update_corpus: self._wrap_method( + self.update_corpus, + default_timeout=None, + client_info=client_info, + ), + self.delete_corpus: self._wrap_method( + self.delete_corpus, + default_timeout=None, + client_info=client_info, + ), + self.list_corpora: self._wrap_method( + self.list_corpora, + default_timeout=None, + client_info=client_info, + ), + self.query_corpus: self._wrap_method( + self.query_corpus, + default_timeout=None, + client_info=client_info, + ), + self.create_document: self._wrap_method( + self.create_document, + default_timeout=None, + client_info=client_info, + ), + self.get_document: self._wrap_method( + self.get_document, + default_timeout=None, + client_info=client_info, + ), + self.update_document: self._wrap_method( + self.update_document, + default_timeout=None, + client_info=client_info, + ), + self.delete_document: self._wrap_method( + self.delete_document, + default_timeout=None, + client_info=client_info, + ), + self.list_documents: self._wrap_method( + self.list_documents, + default_timeout=None, + client_info=client_info, + ), + self.query_document: self._wrap_method( + self.query_document, + default_timeout=None, + client_info=client_info, + ), + self.create_chunk: self._wrap_method( + self.create_chunk, + default_timeout=None, + client_info=client_info, + ), + self.batch_create_chunks: self._wrap_method( + self.batch_create_chunks, + default_timeout=None, + client_info=client_info, + ), + self.get_chunk: self._wrap_method( + self.get_chunk, + default_timeout=None, + client_info=client_info, + ), + self.update_chunk: self._wrap_method( + self.update_chunk, + default_timeout=None, + client_info=client_info, + ), + self.batch_update_chunks: self._wrap_method( + self.batch_update_chunks, + default_timeout=None, + client_info=client_info, + ), + self.delete_chunk: self._wrap_method( + self.delete_chunk, + default_timeout=None, + client_info=client_info, + ), + self.batch_delete_chunks: self._wrap_method( + self.batch_delete_chunks, + default_timeout=None, + client_info=client_info, + ), + self.list_chunks: self._wrap_method( + self.list_chunks, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("RetrieverServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest.py new file mode 100644 index 000000000000..a53b3ac01460 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest.py @@ -0,0 +1,4067 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseRetrieverServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class RetrieverServiceRestInterceptor: + """Interceptor for RetrieverService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the RetrieverServiceRestTransport. + + .. code-block:: python + class MyCustomRetrieverServiceInterceptor(RetrieverServiceRestInterceptor): + def pre_batch_create_chunks(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_batch_create_chunks(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_batch_delete_chunks(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_batch_update_chunks(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_batch_update_chunks(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_create_chunk(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_chunk(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_create_corpus(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_corpus(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_create_document(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_create_document(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_delete_chunk(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_delete_corpus(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_delete_document(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def pre_get_chunk(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_chunk(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_get_corpus(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_corpus(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_get_document(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_get_document(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_chunks(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_chunks(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_corpora(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_corpora(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_list_documents(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_list_documents(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_query_corpus(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_query_corpus(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_query_document(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_query_document(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_chunk(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_chunk(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_corpus(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_corpus(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_update_document(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_update_document(self, response): + logging.log(f"Received response: {response}") + return response + + transport = RetrieverServiceRestTransport(interceptor=MyCustomRetrieverServiceInterceptor()) + client = RetrieverServiceClient(transport=transport) + + + """ + + def pre_batch_create_chunks( + self, + request: retriever_service.BatchCreateChunksRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.BatchCreateChunksRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for batch_create_chunks + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_batch_create_chunks( + self, response: retriever_service.BatchCreateChunksResponse + ) -> retriever_service.BatchCreateChunksResponse: + """Post-rpc interceptor for batch_create_chunks + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_batch_delete_chunks( + self, + request: retriever_service.BatchDeleteChunksRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.BatchDeleteChunksRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for batch_delete_chunks + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def pre_batch_update_chunks( + self, + request: retriever_service.BatchUpdateChunksRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.BatchUpdateChunksRequest, + Sequence[Tuple[str, Union[str, bytes]]], + ]: + """Pre-rpc interceptor for batch_update_chunks + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_batch_update_chunks( + self, response: retriever_service.BatchUpdateChunksResponse + ) -> retriever_service.BatchUpdateChunksResponse: + """Post-rpc interceptor for batch_update_chunks + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_create_chunk( + self, + request: retriever_service.CreateChunkRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.CreateChunkRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for create_chunk + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_create_chunk(self, response: retriever.Chunk) -> retriever.Chunk: + """Post-rpc interceptor for create_chunk + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_create_corpus( + self, + request: retriever_service.CreateCorpusRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.CreateCorpusRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for create_corpus + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_create_corpus(self, response: retriever.Corpus) -> retriever.Corpus: + """Post-rpc interceptor for create_corpus + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_create_document( + self, + request: retriever_service.CreateDocumentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.CreateDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for create_document + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_create_document(self, response: retriever.Document) -> retriever.Document: + """Post-rpc interceptor for create_document + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_delete_chunk( + self, + request: retriever_service.DeleteChunkRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.DeleteChunkRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for delete_chunk + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def pre_delete_corpus( + self, + request: retriever_service.DeleteCorpusRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.DeleteCorpusRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for delete_corpus + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def pre_delete_document( + self, + request: retriever_service.DeleteDocumentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.DeleteDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for delete_document + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def pre_get_chunk( + self, + request: retriever_service.GetChunkRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.GetChunkRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_chunk + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_get_chunk(self, response: retriever.Chunk) -> retriever.Chunk: + """Post-rpc interceptor for get_chunk + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_get_corpus( + self, + request: retriever_service.GetCorpusRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.GetCorpusRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_corpus + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_get_corpus(self, response: retriever.Corpus) -> retriever.Corpus: + """Post-rpc interceptor for get_corpus + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_get_document( + self, + request: retriever_service.GetDocumentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.GetDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_document + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_get_document(self, response: retriever.Document) -> retriever.Document: + """Post-rpc interceptor for get_document + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_list_chunks( + self, + request: retriever_service.ListChunksRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.ListChunksRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_chunks + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_list_chunks( + self, response: retriever_service.ListChunksResponse + ) -> retriever_service.ListChunksResponse: + """Post-rpc interceptor for list_chunks + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_list_corpora( + self, + request: retriever_service.ListCorporaRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.ListCorporaRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_corpora + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_list_corpora( + self, response: retriever_service.ListCorporaResponse + ) -> retriever_service.ListCorporaResponse: + """Post-rpc interceptor for list_corpora + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_list_documents( + self, + request: retriever_service.ListDocumentsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.ListDocumentsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_documents + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_list_documents( + self, response: retriever_service.ListDocumentsResponse + ) -> retriever_service.ListDocumentsResponse: + """Post-rpc interceptor for list_documents + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_query_corpus( + self, + request: retriever_service.QueryCorpusRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.QueryCorpusRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for query_corpus + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_query_corpus( + self, response: retriever_service.QueryCorpusResponse + ) -> retriever_service.QueryCorpusResponse: + """Post-rpc interceptor for query_corpus + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_query_document( + self, + request: retriever_service.QueryDocumentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.QueryDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for query_document + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_query_document( + self, response: retriever_service.QueryDocumentResponse + ) -> retriever_service.QueryDocumentResponse: + """Post-rpc interceptor for query_document + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_update_chunk( + self, + request: retriever_service.UpdateChunkRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.UpdateChunkRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for update_chunk + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_update_chunk(self, response: retriever.Chunk) -> retriever.Chunk: + """Post-rpc interceptor for update_chunk + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_update_corpus( + self, + request: retriever_service.UpdateCorpusRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.UpdateCorpusRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for update_corpus + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_update_corpus(self, response: retriever.Corpus) -> retriever.Corpus: + """Post-rpc interceptor for update_corpus + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_update_document( + self, + request: retriever_service.UpdateDocumentRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + retriever_service.UpdateDocumentRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for update_document + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_update_document(self, response: retriever.Document) -> retriever.Document: + """Post-rpc interceptor for update_document + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the RetrieverService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the RetrieverService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class RetrieverServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: RetrieverServiceRestInterceptor + + +class RetrieverServiceRestTransport(_BaseRetrieverServiceRestTransport): + """REST backend synchronous transport for RetrieverService. + + An API for semantic search over a corpus of user uploaded + content. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[RetrieverServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or RetrieverServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _BatchCreateChunks( + _BaseRetrieverServiceRestTransport._BaseBatchCreateChunks, + RetrieverServiceRestStub, + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.BatchCreateChunks") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.BatchCreateChunksRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.BatchCreateChunksResponse: + r"""Call the batch create chunks method over HTTP. + + Args: + request (~.retriever_service.BatchCreateChunksRequest): + The request object. Request to batch create ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.BatchCreateChunksResponse: + Response from ``BatchCreateChunks`` containing a list of + created ``Chunk``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseBatchCreateChunks._get_http_options() + ) + + request, metadata = self._interceptor.pre_batch_create_chunks( + request, metadata + ) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseBatchCreateChunks._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseBatchCreateChunks._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseBatchCreateChunks._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.BatchCreateChunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "BatchCreateChunks", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._BatchCreateChunks._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.BatchCreateChunksResponse() + pb_resp = retriever_service.BatchCreateChunksResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_batch_create_chunks(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + retriever_service.BatchCreateChunksResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.batch_create_chunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "BatchCreateChunks", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _BatchDeleteChunks( + _BaseRetrieverServiceRestTransport._BaseBatchDeleteChunks, + RetrieverServiceRestStub, + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.BatchDeleteChunks") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.BatchDeleteChunksRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the batch delete chunks method over HTTP. + + Args: + request (~.retriever_service.BatchDeleteChunksRequest): + The request object. Request to batch delete ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseBatchDeleteChunks._get_http_options() + ) + + request, metadata = self._interceptor.pre_batch_delete_chunks( + request, metadata + ) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseBatchDeleteChunks._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseBatchDeleteChunks._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseBatchDeleteChunks._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.BatchDeleteChunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "BatchDeleteChunks", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._BatchDeleteChunks._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _BatchUpdateChunks( + _BaseRetrieverServiceRestTransport._BaseBatchUpdateChunks, + RetrieverServiceRestStub, + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.BatchUpdateChunks") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.BatchUpdateChunksRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.BatchUpdateChunksResponse: + r"""Call the batch update chunks method over HTTP. + + Args: + request (~.retriever_service.BatchUpdateChunksRequest): + The request object. Request to batch update ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.BatchUpdateChunksResponse: + Response from ``BatchUpdateChunks`` containing a list of + updated ``Chunk``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseBatchUpdateChunks._get_http_options() + ) + + request, metadata = self._interceptor.pre_batch_update_chunks( + request, metadata + ) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseBatchUpdateChunks._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseBatchUpdateChunks._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseBatchUpdateChunks._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.BatchUpdateChunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "BatchUpdateChunks", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._BatchUpdateChunks._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.BatchUpdateChunksResponse() + pb_resp = retriever_service.BatchUpdateChunksResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_batch_update_chunks(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = ( + retriever_service.BatchUpdateChunksResponse.to_json(response) + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.batch_update_chunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "BatchUpdateChunks", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _CreateChunk( + _BaseRetrieverServiceRestTransport._BaseCreateChunk, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.CreateChunk") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.CreateChunkRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Call the create chunk method over HTTP. + + Args: + request (~.retriever_service.CreateChunkRequest): + The request object. Request to create a ``Chunk``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Chunk: + A ``Chunk`` is a subpart of a ``Document`` that is + treated as an independent unit for the purposes of + vector representation and storage. A ``Corpus`` can have + a maximum of 1 million ``Chunk``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseCreateChunk._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_chunk(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseCreateChunk._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseCreateChunk._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseCreateChunk._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.CreateChunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "CreateChunk", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._CreateChunk._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Chunk() + pb_resp = retriever.Chunk.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_chunk(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Chunk.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.create_chunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "CreateChunk", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _CreateCorpus( + _BaseRetrieverServiceRestTransport._BaseCreateCorpus, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.CreateCorpus") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.CreateCorpusRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Call the create corpus method over HTTP. + + Args: + request (~.retriever_service.CreateCorpusRequest): + The request object. Request to create a ``Corpus``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Corpus: + A ``Corpus`` is a collection of ``Document``\ s. A + project can create up to 5 corpora. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseCreateCorpus._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_corpus(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseCreateCorpus._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseCreateCorpus._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseCreateCorpus._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.CreateCorpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "CreateCorpus", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._CreateCorpus._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Corpus() + pb_resp = retriever.Corpus.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_corpus(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Corpus.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.create_corpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "CreateCorpus", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _CreateDocument( + _BaseRetrieverServiceRestTransport._BaseCreateDocument, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.CreateDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.CreateDocumentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Call the create document method over HTTP. + + Args: + request (~.retriever_service.CreateDocumentRequest): + The request object. Request to create a ``Document``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Document: + A ``Document`` is a collection of ``Chunk``\ s. A + ``Corpus`` can have a maximum of 10,000 ``Document``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseCreateDocument._get_http_options() + ) + + request, metadata = self._interceptor.pre_create_document(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseCreateDocument._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseCreateDocument._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseCreateDocument._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.CreateDocument", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "CreateDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._CreateDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Document() + pb_resp = retriever.Document.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_create_document(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Document.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.create_document", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "CreateDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _DeleteChunk( + _BaseRetrieverServiceRestTransport._BaseDeleteChunk, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.DeleteChunk") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.DeleteChunkRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete chunk method over HTTP. + + Args: + request (~.retriever_service.DeleteChunkRequest): + The request object. Request to delete a ``Chunk``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseDeleteChunk._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_chunk(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseDeleteChunk._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseDeleteChunk._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.DeleteChunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "DeleteChunk", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._DeleteChunk._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _DeleteCorpus( + _BaseRetrieverServiceRestTransport._BaseDeleteCorpus, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.DeleteCorpus") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.DeleteCorpusRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete corpus method over HTTP. + + Args: + request (~.retriever_service.DeleteCorpusRequest): + The request object. Request to delete a ``Corpus``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseDeleteCorpus._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_corpus(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseDeleteCorpus._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseDeleteCorpus._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.DeleteCorpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "DeleteCorpus", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._DeleteCorpus._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _DeleteDocument( + _BaseRetrieverServiceRestTransport._BaseDeleteDocument, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.DeleteDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.DeleteDocumentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ): + r"""Call the delete document method over HTTP. + + Args: + request (~.retriever_service.DeleteDocumentRequest): + The request object. Request to delete a ``Document``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseDeleteDocument._get_http_options() + ) + + request, metadata = self._interceptor.pre_delete_document(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseDeleteDocument._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseDeleteDocument._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.DeleteDocument", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "DeleteDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._DeleteDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + class _GetChunk( + _BaseRetrieverServiceRestTransport._BaseGetChunk, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.GetChunk") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.GetChunkRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Call the get chunk method over HTTP. + + Args: + request (~.retriever_service.GetChunkRequest): + The request object. Request for getting information about a specific + ``Chunk``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Chunk: + A ``Chunk`` is a subpart of a ``Document`` that is + treated as an independent unit for the purposes of + vector representation and storage. A ``Corpus`` can have + a maximum of 1 million ``Chunk``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseGetChunk._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_chunk(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseGetChunk._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = ( + _BaseRetrieverServiceRestTransport._BaseGetChunk._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.GetChunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetChunk", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._GetChunk._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Chunk() + pb_resp = retriever.Chunk.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_chunk(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Chunk.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.get_chunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetChunk", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GetCorpus( + _BaseRetrieverServiceRestTransport._BaseGetCorpus, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.GetCorpus") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.GetCorpusRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Call the get corpus method over HTTP. + + Args: + request (~.retriever_service.GetCorpusRequest): + The request object. Request for getting information about a specific + ``Corpus``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Corpus: + A ``Corpus`` is a collection of ``Document``\ s. A + project can create up to 5 corpora. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseGetCorpus._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_corpus(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseGetCorpus._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseGetCorpus._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.GetCorpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetCorpus", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._GetCorpus._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Corpus() + pb_resp = retriever.Corpus.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_corpus(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Corpus.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.get_corpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetCorpus", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GetDocument( + _BaseRetrieverServiceRestTransport._BaseGetDocument, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.GetDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.GetDocumentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Call the get document method over HTTP. + + Args: + request (~.retriever_service.GetDocumentRequest): + The request object. Request for getting information about a specific + ``Document``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Document: + A ``Document`` is a collection of ``Chunk``\ s. A + ``Corpus`` can have a maximum of 10,000 ``Document``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseGetDocument._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_document(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseGetDocument._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseGetDocument._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.GetDocument", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._GetDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Document() + pb_resp = retriever.Document.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_get_document(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Document.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.get_document", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListChunks( + _BaseRetrieverServiceRestTransport._BaseListChunks, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.ListChunks") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.ListChunksRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.ListChunksResponse: + r"""Call the list chunks method over HTTP. + + Args: + request (~.retriever_service.ListChunksRequest): + The request object. Request for listing ``Chunk``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.ListChunksResponse: + Response from ``ListChunks`` containing a paginated list + of ``Chunk``\ s. The ``Chunk``\ s are sorted by + ascending ``chunk.create_time``. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseListChunks._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_chunks(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseListChunks._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseListChunks._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.ListChunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListChunks", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._ListChunks._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.ListChunksResponse() + pb_resp = retriever_service.ListChunksResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_chunks(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever_service.ListChunksResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.list_chunks", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListChunks", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListCorpora( + _BaseRetrieverServiceRestTransport._BaseListCorpora, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.ListCorpora") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.ListCorporaRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.ListCorporaResponse: + r"""Call the list corpora method over HTTP. + + Args: + request (~.retriever_service.ListCorporaRequest): + The request object. Request for listing ``Corpora``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.ListCorporaResponse: + Response from ``ListCorpora`` containing a paginated + list of ``Corpora``. The results are sorted by ascending + ``corpus.create_time``. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseListCorpora._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_corpora(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseListCorpora._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseListCorpora._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.ListCorpora", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListCorpora", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._ListCorpora._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.ListCorporaResponse() + pb_resp = retriever_service.ListCorporaResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_corpora(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever_service.ListCorporaResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.list_corpora", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListCorpora", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _ListDocuments( + _BaseRetrieverServiceRestTransport._BaseListDocuments, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.ListDocuments") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: retriever_service.ListDocumentsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.ListDocumentsResponse: + r"""Call the list documents method over HTTP. + + Args: + request (~.retriever_service.ListDocumentsRequest): + The request object. Request for listing ``Document``\ s. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.ListDocumentsResponse: + Response from ``ListDocuments`` containing a paginated + list of ``Document``\ s. The ``Document``\ s are sorted + by ascending ``document.create_time``. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseListDocuments._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_documents(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseListDocuments._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseListDocuments._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.ListDocuments", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListDocuments", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._ListDocuments._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.ListDocumentsResponse() + pb_resp = retriever_service.ListDocumentsResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_list_documents(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever_service.ListDocumentsResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.list_documents", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListDocuments", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _QueryCorpus( + _BaseRetrieverServiceRestTransport._BaseQueryCorpus, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.QueryCorpus") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.QueryCorpusRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.QueryCorpusResponse: + r"""Call the query corpus method over HTTP. + + Args: + request (~.retriever_service.QueryCorpusRequest): + The request object. Request for querying a ``Corpus``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.QueryCorpusResponse: + Response from ``QueryCorpus`` containing a list of + relevant chunks. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseQueryCorpus._get_http_options() + ) + + request, metadata = self._interceptor.pre_query_corpus(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseQueryCorpus._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseQueryCorpus._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseQueryCorpus._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.QueryCorpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "QueryCorpus", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._QueryCorpus._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.QueryCorpusResponse() + pb_resp = retriever_service.QueryCorpusResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_query_corpus(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever_service.QueryCorpusResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.query_corpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "QueryCorpus", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _QueryDocument( + _BaseRetrieverServiceRestTransport._BaseQueryDocument, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.QueryDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.QueryDocumentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever_service.QueryDocumentResponse: + r"""Call the query document method over HTTP. + + Args: + request (~.retriever_service.QueryDocumentRequest): + The request object. Request for querying a ``Document``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever_service.QueryDocumentResponse: + Response from ``QueryDocument`` containing a list of + relevant chunks. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseQueryDocument._get_http_options() + ) + + request, metadata = self._interceptor.pre_query_document(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseQueryDocument._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseQueryDocument._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseQueryDocument._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.QueryDocument", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "QueryDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._QueryDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever_service.QueryDocumentResponse() + pb_resp = retriever_service.QueryDocumentResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_query_document(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever_service.QueryDocumentResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.query_document", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "QueryDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _UpdateChunk( + _BaseRetrieverServiceRestTransport._BaseUpdateChunk, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.UpdateChunk") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.UpdateChunkRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Chunk: + r"""Call the update chunk method over HTTP. + + Args: + request (~.retriever_service.UpdateChunkRequest): + The request object. Request to update a ``Chunk``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Chunk: + A ``Chunk`` is a subpart of a ``Document`` that is + treated as an independent unit for the purposes of + vector representation and storage. A ``Corpus`` can have + a maximum of 1 million ``Chunk``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseUpdateChunk._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_chunk(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseUpdateChunk._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseUpdateChunk._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseUpdateChunk._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.UpdateChunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "UpdateChunk", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._UpdateChunk._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Chunk() + pb_resp = retriever.Chunk.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_chunk(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Chunk.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.update_chunk", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "UpdateChunk", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _UpdateCorpus( + _BaseRetrieverServiceRestTransport._BaseUpdateCorpus, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.UpdateCorpus") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.UpdateCorpusRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Corpus: + r"""Call the update corpus method over HTTP. + + Args: + request (~.retriever_service.UpdateCorpusRequest): + The request object. Request to update a ``Corpus``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Corpus: + A ``Corpus`` is a collection of ``Document``\ s. A + project can create up to 5 corpora. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseUpdateCorpus._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_corpus(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseUpdateCorpus._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseUpdateCorpus._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseUpdateCorpus._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.UpdateCorpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "UpdateCorpus", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._UpdateCorpus._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Corpus() + pb_resp = retriever.Corpus.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_corpus(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Corpus.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.update_corpus", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "UpdateCorpus", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _UpdateDocument( + _BaseRetrieverServiceRestTransport._BaseUpdateDocument, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.UpdateDocument") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: retriever_service.UpdateDocumentRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> retriever.Document: + r"""Call the update document method over HTTP. + + Args: + request (~.retriever_service.UpdateDocumentRequest): + The request object. Request to update a ``Document``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.retriever.Document: + A ``Document`` is a collection of ``Chunk``\ s. A + ``Corpus`` can have a maximum of 10,000 ``Document``\ s. + + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseUpdateDocument._get_http_options() + ) + + request, metadata = self._interceptor.pre_update_document(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseUpdateDocument._get_transcoded_request( + http_options, request + ) + + body = _BaseRetrieverServiceRestTransport._BaseUpdateDocument._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseUpdateDocument._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.UpdateDocument", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "UpdateDocument", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._UpdateDocument._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = retriever.Document() + pb_resp = retriever.Document.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_update_document(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = retriever.Document.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.update_document", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "UpdateDocument", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def batch_create_chunks( + self, + ) -> Callable[ + [retriever_service.BatchCreateChunksRequest], + retriever_service.BatchCreateChunksResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchCreateChunks(self._session, self._host, self._interceptor) # type: ignore + + @property + def batch_delete_chunks( + self, + ) -> Callable[[retriever_service.BatchDeleteChunksRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchDeleteChunks(self._session, self._host, self._interceptor) # type: ignore + + @property + def batch_update_chunks( + self, + ) -> Callable[ + [retriever_service.BatchUpdateChunksRequest], + retriever_service.BatchUpdateChunksResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchUpdateChunks(self._session, self._host, self._interceptor) # type: ignore + + @property + def create_chunk( + self, + ) -> Callable[[retriever_service.CreateChunkRequest], retriever.Chunk]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateChunk(self._session, self._host, self._interceptor) # type: ignore + + @property + def create_corpus( + self, + ) -> Callable[[retriever_service.CreateCorpusRequest], retriever.Corpus]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateCorpus(self._session, self._host, self._interceptor) # type: ignore + + @property + def create_document( + self, + ) -> Callable[[retriever_service.CreateDocumentRequest], retriever.Document]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CreateDocument(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_chunk( + self, + ) -> Callable[[retriever_service.DeleteChunkRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteChunk(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_corpus( + self, + ) -> Callable[[retriever_service.DeleteCorpusRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteCorpus(self._session, self._host, self._interceptor) # type: ignore + + @property + def delete_document( + self, + ) -> Callable[[retriever_service.DeleteDocumentRequest], empty_pb2.Empty]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._DeleteDocument(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_chunk( + self, + ) -> Callable[[retriever_service.GetChunkRequest], retriever.Chunk]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetChunk(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_corpus( + self, + ) -> Callable[[retriever_service.GetCorpusRequest], retriever.Corpus]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetCorpus(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_document( + self, + ) -> Callable[[retriever_service.GetDocumentRequest], retriever.Document]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GetDocument(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_chunks( + self, + ) -> Callable[ + [retriever_service.ListChunksRequest], retriever_service.ListChunksResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListChunks(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_corpora( + self, + ) -> Callable[ + [retriever_service.ListCorporaRequest], retriever_service.ListCorporaResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListCorpora(self._session, self._host, self._interceptor) # type: ignore + + @property + def list_documents( + self, + ) -> Callable[ + [retriever_service.ListDocumentsRequest], + retriever_service.ListDocumentsResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._ListDocuments(self._session, self._host, self._interceptor) # type: ignore + + @property + def query_corpus( + self, + ) -> Callable[ + [retriever_service.QueryCorpusRequest], retriever_service.QueryCorpusResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._QueryCorpus(self._session, self._host, self._interceptor) # type: ignore + + @property + def query_document( + self, + ) -> Callable[ + [retriever_service.QueryDocumentRequest], + retriever_service.QueryDocumentResponse, + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._QueryDocument(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_chunk( + self, + ) -> Callable[[retriever_service.UpdateChunkRequest], retriever.Chunk]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateChunk(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_corpus( + self, + ) -> Callable[[retriever_service.UpdateCorpusRequest], retriever.Corpus]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateCorpus(self._session, self._host, self._interceptor) # type: ignore + + @property + def update_document( + self, + ) -> Callable[[retriever_service.UpdateDocumentRequest], retriever.Document]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._UpdateDocument(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseRetrieverServiceRestTransport._BaseGetOperation, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseRetrieverServiceRestTransport._BaseListOperations, RetrieverServiceRestStub + ): + def __hash__(self): + return hash("RetrieverServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseRetrieverServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseRetrieverServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseRetrieverServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.RetrieverServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = RetrieverServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("RetrieverServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest_base.py new file mode 100644 index 000000000000..b5fb721e220f --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/retriever_service/transports/rest_base.py @@ -0,0 +1,1196 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + +from .base import DEFAULT_CLIENT_INFO, RetrieverServiceTransport + + +class _BaseRetrieverServiceRestTransport(RetrieverServiceTransport): + """Base REST backend transport for RetrieverService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseBatchCreateChunks: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{parent=corpora/*/documents/*}/chunks:batchCreate", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.BatchCreateChunksRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseBatchCreateChunks._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseBatchDeleteChunks: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{parent=corpora/*/documents/*}/chunks:batchDelete", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.BatchDeleteChunksRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseBatchDeleteChunks._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseBatchUpdateChunks: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{parent=corpora/*/documents/*}/chunks:batchUpdate", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.BatchUpdateChunksRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseBatchUpdateChunks._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseCreateChunk: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{parent=corpora/*/documents/*}/chunks", + "body": "chunk", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.CreateChunkRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseCreateChunk._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseCreateCorpus: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/corpora", + "body": "corpus", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.CreateCorpusRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseCreateCorpus._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseCreateDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{parent=corpora/*}/documents", + "body": "document", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.CreateDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseCreateDocument._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeleteChunk: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=corpora/*/documents/*/chunks/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.DeleteChunkRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseDeleteChunk._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeleteCorpus: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=corpora/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.DeleteCorpusRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseDeleteCorpus._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseDeleteDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "delete", + "uri": "/v1alpha/{name=corpora/*/documents/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.DeleteDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseDeleteDocument._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetChunk: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=corpora/*/documents/*/chunks/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.GetChunkRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseGetChunk._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetCorpus: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=corpora/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.GetCorpusRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseGetCorpus._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=corpora/*/documents/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.GetDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseGetDocument._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListChunks: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{parent=corpora/*/documents/*}/chunks", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.ListChunksRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseListChunks._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListCorpora: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/corpora", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.ListCorporaRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseListDocuments: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{parent=corpora/*}/documents", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.ListDocumentsRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseListDocuments._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseQueryCorpus: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{name=corpora/*}:query", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.QueryCorpusRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseQueryCorpus._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseQueryDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{name=corpora/*/documents/*}:query", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.QueryDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseQueryDocument._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseUpdateChunk: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1alpha/{chunk.name=corpora/*/documents/*/chunks/*}", + "body": "chunk", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.UpdateChunkRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseUpdateChunk._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseUpdateCorpus: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1alpha/{corpus.name=corpora/*}", + "body": "corpus", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.UpdateCorpusRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseUpdateCorpus._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseUpdateDocument: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = { + "updateMask": {}, + } + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "patch", + "uri": "/v1alpha/{document.name=corpora/*/documents/*}", + "body": "document", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = retriever_service.UpdateDocumentRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseRetrieverServiceRestTransport._BaseUpdateDocument._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseRetrieverServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/__init__.py new file mode 100644 index 000000000000..56818276d243 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/__init__.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .async_client import TextServiceAsyncClient +from .client import TextServiceClient + +__all__ = ( + "TextServiceClient", + "TextServiceAsyncClient", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/async_client.py new file mode 100644 index 000000000000..f9dd0d629eba --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/async_client.py @@ -0,0 +1,1005 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry_async as retries +from google.api_core.client_options import ClientOptions +from google.auth import credentials as ga_credentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.AsyncRetry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.AsyncRetry, object, None] # type: ignore + +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import safety, text_service + +from .client import TextServiceClient +from .transports.base import DEFAULT_CLIENT_INFO, TextServiceTransport +from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class TextServiceAsyncClient: + """API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + """ + + _client: TextServiceClient + + # Copy defaults from the synchronous client for use here. + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = TextServiceClient.DEFAULT_ENDPOINT + DEFAULT_MTLS_ENDPOINT = TextServiceClient.DEFAULT_MTLS_ENDPOINT + _DEFAULT_ENDPOINT_TEMPLATE = TextServiceClient._DEFAULT_ENDPOINT_TEMPLATE + _DEFAULT_UNIVERSE = TextServiceClient._DEFAULT_UNIVERSE + + model_path = staticmethod(TextServiceClient.model_path) + parse_model_path = staticmethod(TextServiceClient.parse_model_path) + common_billing_account_path = staticmethod( + TextServiceClient.common_billing_account_path + ) + parse_common_billing_account_path = staticmethod( + TextServiceClient.parse_common_billing_account_path + ) + common_folder_path = staticmethod(TextServiceClient.common_folder_path) + parse_common_folder_path = staticmethod(TextServiceClient.parse_common_folder_path) + common_organization_path = staticmethod(TextServiceClient.common_organization_path) + parse_common_organization_path = staticmethod( + TextServiceClient.parse_common_organization_path + ) + common_project_path = staticmethod(TextServiceClient.common_project_path) + parse_common_project_path = staticmethod( + TextServiceClient.parse_common_project_path + ) + common_location_path = staticmethod(TextServiceClient.common_location_path) + parse_common_location_path = staticmethod( + TextServiceClient.parse_common_location_path + ) + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceAsyncClient: The constructed client. + """ + return TextServiceClient.from_service_account_info.__func__(TextServiceAsyncClient, info, *args, **kwargs) # type: ignore + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceAsyncClient: The constructed client. + """ + return TextServiceClient.from_service_account_file.__func__(TextServiceAsyncClient, filename, *args, **kwargs) # type: ignore + + from_service_account_json = from_service_account_file + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[ClientOptions] = None + ): + """Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + return TextServiceClient.get_mtls_endpoint_and_cert_source(client_options) # type: ignore + + @property + def transport(self) -> TextServiceTransport: + """Returns the transport used by the client instance. + + Returns: + TextServiceTransport: The transport used by the client instance. + """ + return self._client.transport + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._client._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used + by the client instance. + """ + return self._client._universe_domain + + get_transport_class = TextServiceClient.get_transport_class + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, TextServiceTransport, Callable[..., TextServiceTransport]] + ] = "grpc_asyncio", + client_options: Optional[ClientOptions] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the text service async client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,TextServiceTransport,Callable[..., TextServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport to use. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the TextServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client = TextServiceClient( + credentials=credentials, + transport=transport, + client_options=client_options, + client_info=client_info, + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.TextServiceAsyncClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "universeDomain": getattr( + self._client._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._client._transport._credentials).__module__}.{type(self._client._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._client._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "credentialsType": None, + }, + ) + + async def generate_text( + self, + request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.GenerateTextResponse: + r"""Generates a response from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_generate_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_text(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.GenerateTextRequest, dict]]): + The request object. Request to generate a text completion + response from the model. + model (:class:`str`): + Required. The name of the ``Model`` or ``TunedModel`` to + use for generating the completion. Examples: + models/text-bison-001 + tunedModels/sentence-translator-u3b7m + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1alpha.types.TextPrompt`): + Required. The free-form input text + given to the model as a prompt. + Given a prompt, the model will generate + a TextCompletion response it predicts as + the completion of the input text. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (:class:`float`): + Optional. Controls the randomness of the output. Note: + The default value varies by model, see the + ``Model.temperature`` attribute of the ``Model`` + returned the ``getModel`` function. + + Values can range from [0.0,1.0], inclusive. A value + closer to 1.0 will produce responses that are more + varied and creative, while a value closer to 0.0 will + typically result in more straightforward responses from + the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (:class:`int`): + Optional. Number of generated responses to return. + + This value must be between [1, 8], inclusive. If unset, + this will default to 1. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + max_output_tokens (:class:`int`): + Optional. The maximum number of tokens to include in a + candidate. + + If unset, this will default to output_token_limit + specified in the ``Model`` specification. + + This corresponds to the ``max_output_tokens`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (:class:`float`): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Tokens are sorted based on their assigned probabilities + so that only the most likely tokens are considered. + Top-k sampling directly limits the maximum number of + tokens to consider, while Nucleus sampling limits number + of tokens based on the cumulative probability. + + Note: The default value varies by model, see the + ``Model.top_p`` attribute of the ``Model`` returned the + ``getModel`` function. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (:class:`int`): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. Defaults to 40. + + Note: The default value varies by model, see the + ``Model.top_k`` attribute of the ``Model`` returned the + ``getModel`` function. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateTextResponse: + The response from the model, + including candidate completions. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any( + [ + model, + prompt, + temperature, + candidate_count, + max_output_tokens, + top_p, + top_k, + ] + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.GenerateTextRequest): + request = text_service.GenerateTextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if max_output_tokens is not None: + request.max_output_tokens = max_output_tokens + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.generate_text + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def embed_text( + self, + request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + text: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.EmbedTextResponse: + r"""Generates an embedding from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedTextRequest( + model="model_value", + ) + + # Make the request + response = await client.embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.EmbedTextRequest, dict]]): + The request object. Request to get a text embedding from + the model. + model (:class:`str`): + Required. The model name to use with + the format model=models/{model}. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + text (:class:`str`): + Optional. The free-form input text + that the model will turn into an + embedding. + + This corresponds to the ``text`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.EmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, text]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.EmbedTextRequest): + request = text_service.EmbedTextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if text is not None: + request.text = text + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.embed_text + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def batch_embed_text( + self, + request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + texts: Optional[MutableSequence[str]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.BatchEmbedTextResponse: + r"""Generates multiple embeddings from the model given + input text in a synchronous call. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.BatchEmbedTextRequest( + model="model_value", + ) + + # Make the request + response = await client.batch_embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.BatchEmbedTextRequest, dict]]): + The request object. Batch request to get a text embedding + from the model. + model (:class:`str`): + Required. The name of the ``Model`` to use for + generating the embedding. Examples: + models/embedding-gecko-001 + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + texts (:class:`MutableSequence[str]`): + Optional. The free-form input texts + that the model will turn into an + embedding. The current limit is 100 + texts, over which an error will be + thrown. + + This corresponds to the ``texts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchEmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, texts]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.BatchEmbedTextRequest): + request = text_service.BatchEmbedTextRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if texts: + request.texts.extend(texts) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.batch_embed_text + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def count_text_tokens( + self, + request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.CountTextTokensResponse: + r"""Runs a model's tokenizer on a text and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + async def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_text_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Optional[Union[google.ai.generativelanguage_v1alpha.types.CountTextTokensRequest, dict]]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (:class:`str`): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (:class:`google.ai.generativelanguage_v1alpha.types.TextPrompt`): + Required. The free-form input text + given to the model as a prompt. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CountTextTokensResponse: + A response from CountTextTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.CountTextTokensRequest): + request = text_service.CountTextTokensRequest(request) + + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._client._transport._wrapped_methods[ + self._client._transport.count_text_tokens + ] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry_async.AsyncRetry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self.transport._wrapped_methods[self._client._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._client._validate_universe_domain() + + # Send the request. + response = await rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + async def __aenter__(self) -> "TextServiceAsyncClient": + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.transport.close() + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("TextServiceAsyncClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/client.py new file mode 100644 index 000000000000..39955a0126e3 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/client.py @@ -0,0 +1,1386 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +import logging as std_logging +import os +import re +from typing import ( + Callable, + Dict, + Mapping, + MutableMapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +import warnings + +from google.api_core import client_options as client_options_lib +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.exceptions import MutualTLSChannelError # type: ignore +from google.auth.transport import mtls # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + +from google.longrunning import operations_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.types import safety, text_service + +from .transports.base import DEFAULT_CLIENT_INFO, TextServiceTransport +from .transports.grpc import TextServiceGrpcTransport +from .transports.grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .transports.rest import TextServiceRestTransport + + +class TextServiceClientMeta(type): + """Metaclass for the TextService client. + + This provides class-level methods for building and retrieving + support objects (e.g. transport) without polluting the client instance + objects. + """ + + _transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] + _transport_registry["grpc"] = TextServiceGrpcTransport + _transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport + _transport_registry["rest"] = TextServiceRestTransport + + def get_transport_class( + cls, + label: Optional[str] = None, + ) -> Type[TextServiceTransport]: + """Returns an appropriate transport class. + + Args: + label: The name of the desired transport. If none is + provided, then the first transport in the registry is used. + + Returns: + The transport class to use. + """ + # If a specific transport is requested, return that one. + if label: + return cls._transport_registry[label] + + # No transport is requested; return the default (that is, the first one + # in the dictionary). + return next(iter(cls._transport_registry.values())) + + +class TextServiceClient(metaclass=TextServiceClientMeta): + """API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + """ + + @staticmethod + def _get_default_mtls_endpoint(api_endpoint): + """Converts api endpoint to mTLS endpoint. + + Convert "*.sandbox.googleapis.com" and "*.googleapis.com" to + "*.mtls.sandbox.googleapis.com" and "*.mtls.googleapis.com" respectively. + Args: + api_endpoint (Optional[str]): the api endpoint to convert. + Returns: + str: converted mTLS api endpoint. + """ + if not api_endpoint: + return api_endpoint + + mtls_endpoint_re = re.compile( + r"(?P[^.]+)(?P\.mtls)?(?P\.sandbox)?(?P\.googleapis\.com)?" + ) + + m = mtls_endpoint_re.match(api_endpoint) + name, mtls, sandbox, googledomain = m.groups() + if mtls or not googledomain: + return api_endpoint + + if sandbox: + return api_endpoint.replace( + "sandbox.googleapis.com", "mtls.sandbox.googleapis.com" + ) + + return api_endpoint.replace(".googleapis.com", ".mtls.googleapis.com") + + # Note: DEFAULT_ENDPOINT is deprecated. Use _DEFAULT_ENDPOINT_TEMPLATE instead. + DEFAULT_ENDPOINT = "generativelanguage.googleapis.com" + DEFAULT_MTLS_ENDPOINT = _get_default_mtls_endpoint.__func__( # type: ignore + DEFAULT_ENDPOINT + ) + + _DEFAULT_ENDPOINT_TEMPLATE = "generativelanguage.{UNIVERSE_DOMAIN}" + _DEFAULT_UNIVERSE = "googleapis.com" + + @classmethod + def from_service_account_info(cls, info: dict, *args, **kwargs): + """Creates an instance of this client using the provided credentials + info. + + Args: + info (dict): The service account private key info. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_info(info) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + @classmethod + def from_service_account_file(cls, filename: str, *args, **kwargs): + """Creates an instance of this client using the provided credentials + file. + + Args: + filename (str): The path to the service account private key json + file. + args: Additional arguments to pass to the constructor. + kwargs: Additional arguments to pass to the constructor. + + Returns: + TextServiceClient: The constructed client. + """ + credentials = service_account.Credentials.from_service_account_file(filename) + kwargs["credentials"] = credentials + return cls(*args, **kwargs) + + from_service_account_json = from_service_account_file + + @property + def transport(self) -> TextServiceTransport: + """Returns the transport used by the client instance. + + Returns: + TextServiceTransport: The transport used by the client + instance. + """ + return self._transport + + @staticmethod + def model_path( + model: str, + ) -> str: + """Returns a fully-qualified model string.""" + return "models/{model}".format( + model=model, + ) + + @staticmethod + def parse_model_path(path: str) -> Dict[str, str]: + """Parses a model path into its component segments.""" + m = re.match(r"^models/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_billing_account_path( + billing_account: str, + ) -> str: + """Returns a fully-qualified billing_account string.""" + return "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + + @staticmethod + def parse_common_billing_account_path(path: str) -> Dict[str, str]: + """Parse a billing_account path into its component segments.""" + m = re.match(r"^billingAccounts/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_folder_path( + folder: str, + ) -> str: + """Returns a fully-qualified folder string.""" + return "folders/{folder}".format( + folder=folder, + ) + + @staticmethod + def parse_common_folder_path(path: str) -> Dict[str, str]: + """Parse a folder path into its component segments.""" + m = re.match(r"^folders/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_organization_path( + organization: str, + ) -> str: + """Returns a fully-qualified organization string.""" + return "organizations/{organization}".format( + organization=organization, + ) + + @staticmethod + def parse_common_organization_path(path: str) -> Dict[str, str]: + """Parse a organization path into its component segments.""" + m = re.match(r"^organizations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_project_path( + project: str, + ) -> str: + """Returns a fully-qualified project string.""" + return "projects/{project}".format( + project=project, + ) + + @staticmethod + def parse_common_project_path(path: str) -> Dict[str, str]: + """Parse a project path into its component segments.""" + m = re.match(r"^projects/(?P.+?)$", path) + return m.groupdict() if m else {} + + @staticmethod + def common_location_path( + project: str, + location: str, + ) -> str: + """Returns a fully-qualified location string.""" + return "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + + @staticmethod + def parse_common_location_path(path: str) -> Dict[str, str]: + """Parse a location path into its component segments.""" + m = re.match(r"^projects/(?P.+?)/locations/(?P.+?)$", path) + return m.groupdict() if m else {} + + @classmethod + def get_mtls_endpoint_and_cert_source( + cls, client_options: Optional[client_options_lib.ClientOptions] = None + ): + """Deprecated. Return the API endpoint and client cert source for mutual TLS. + + The client cert source is determined in the following order: + (1) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is not "true", the + client cert source is None. + (2) if `client_options.client_cert_source` is provided, use the provided one; if the + default client cert source exists, use the default one; otherwise the client cert + source is None. + + The API endpoint is determined in the following order: + (1) if `client_options.api_endpoint` if provided, use the provided one. + (2) if `GOOGLE_API_USE_CLIENT_CERTIFICATE` environment variable is "always", use the + default mTLS endpoint; if the environment variable is "never", use the default API + endpoint; otherwise if client cert source exists, use the default mTLS endpoint, otherwise + use the default API endpoint. + + More details can be found at https://google.aip.dev/auth/4114. + + Args: + client_options (google.api_core.client_options.ClientOptions): Custom options for the + client. Only the `api_endpoint` and `client_cert_source` properties may be used + in this method. + + Returns: + Tuple[str, Callable[[], Tuple[bytes, bytes]]]: returns the API endpoint and the + client cert source to use. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If any errors happen. + """ + + warnings.warn( + "get_mtls_endpoint_and_cert_source is deprecated. Use the api_endpoint property instead.", + DeprecationWarning, + ) + if client_options is None: + client_options = client_options_lib.ClientOptions() + use_client_cert = os.getenv("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Figure out the client cert source to use. + client_cert_source = None + if use_client_cert == "true": + if client_options.client_cert_source: + client_cert_source = client_options.client_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + + # Figure out which api endpoint to use. + if client_options.api_endpoint is not None: + api_endpoint = client_options.api_endpoint + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + api_endpoint = cls.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = cls.DEFAULT_ENDPOINT + + return api_endpoint, client_cert_source + + @staticmethod + def _read_environment_variables(): + """Returns the environment variables used by the client. + + Returns: + Tuple[bool, str, str]: returns the GOOGLE_API_USE_CLIENT_CERTIFICATE, + GOOGLE_API_USE_MTLS_ENDPOINT, and GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variables. + + Raises: + ValueError: If GOOGLE_API_USE_CLIENT_CERTIFICATE is not + any of ["true", "false"]. + google.auth.exceptions.MutualTLSChannelError: If GOOGLE_API_USE_MTLS_ENDPOINT + is not any of ["auto", "never", "always"]. + """ + use_client_cert = os.getenv( + "GOOGLE_API_USE_CLIENT_CERTIFICATE", "false" + ).lower() + use_mtls_endpoint = os.getenv("GOOGLE_API_USE_MTLS_ENDPOINT", "auto").lower() + universe_domain_env = os.getenv("GOOGLE_CLOUD_UNIVERSE_DOMAIN") + if use_client_cert not in ("true", "false"): + raise ValueError( + "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + if use_mtls_endpoint not in ("auto", "never", "always"): + raise MutualTLSChannelError( + "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + return use_client_cert == "true", use_mtls_endpoint, universe_domain_env + + @staticmethod + def _get_client_cert_source(provided_cert_source, use_cert_flag): + """Return the client cert source to be used by the client. + + Args: + provided_cert_source (bytes): The client certificate source provided. + use_cert_flag (bool): A flag indicating whether to use the client certificate. + + Returns: + bytes or None: The client cert source to be used by the client. + """ + client_cert_source = None + if use_cert_flag: + if provided_cert_source: + client_cert_source = provided_cert_source + elif mtls.has_default_client_cert_source(): + client_cert_source = mtls.default_client_cert_source() + return client_cert_source + + @staticmethod + def _get_api_endpoint( + api_override, client_cert_source, universe_domain, use_mtls_endpoint + ): + """Return the API endpoint used by the client. + + Args: + api_override (str): The API endpoint override. If specified, this is always + the return value of this function and the other arguments are not used. + client_cert_source (bytes): The client certificate source used by the client. + universe_domain (str): The universe domain used by the client. + use_mtls_endpoint (str): How to use the mTLS endpoint, which depends also on the other parameters. + Possible values are "always", "auto", or "never". + + Returns: + str: The API endpoint to be used by the client. + """ + if api_override is not None: + api_endpoint = api_override + elif use_mtls_endpoint == "always" or ( + use_mtls_endpoint == "auto" and client_cert_source + ): + _default_universe = TextServiceClient._DEFAULT_UNIVERSE + if universe_domain != _default_universe: + raise MutualTLSChannelError( + f"mTLS is not supported in any universe other than {_default_universe}." + ) + api_endpoint = TextServiceClient.DEFAULT_MTLS_ENDPOINT + else: + api_endpoint = TextServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=universe_domain + ) + return api_endpoint + + @staticmethod + def _get_universe_domain( + client_universe_domain: Optional[str], universe_domain_env: Optional[str] + ) -> str: + """Return the universe domain used by the client. + + Args: + client_universe_domain (Optional[str]): The universe domain configured via the client options. + universe_domain_env (Optional[str]): The universe domain configured via the "GOOGLE_CLOUD_UNIVERSE_DOMAIN" environment variable. + + Returns: + str: The universe domain to be used by the client. + + Raises: + ValueError: If the universe domain is an empty string. + """ + universe_domain = TextServiceClient._DEFAULT_UNIVERSE + if client_universe_domain is not None: + universe_domain = client_universe_domain + elif universe_domain_env is not None: + universe_domain = universe_domain_env + if len(universe_domain.strip()) == 0: + raise ValueError("Universe Domain cannot be an empty string.") + return universe_domain + + def _validate_universe_domain(self): + """Validates client's and credentials' universe domains are consistent. + + Returns: + bool: True iff the configured universe domain is valid. + + Raises: + ValueError: If the configured universe domain is not valid. + """ + + # NOTE (b/349488459): universe validation is disabled until further notice. + return True + + @property + def api_endpoint(self): + """Return the API endpoint used by the client instance. + + Returns: + str: The API endpoint used by the client instance. + """ + return self._api_endpoint + + @property + def universe_domain(self) -> str: + """Return the universe domain used by the client instance. + + Returns: + str: The universe domain used by the client instance. + """ + return self._universe_domain + + def __init__( + self, + *, + credentials: Optional[ga_credentials.Credentials] = None, + transport: Optional[ + Union[str, TextServiceTransport, Callable[..., TextServiceTransport]] + ] = None, + client_options: Optional[Union[client_options_lib.ClientOptions, dict]] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + ) -> None: + """Instantiates the text service client. + + Args: + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + transport (Optional[Union[str,TextServiceTransport,Callable[..., TextServiceTransport]]]): + The transport to use, or a Callable that constructs and returns a new transport. + If a Callable is given, it will be called with the same set of initialization + arguments as used in the TextServiceTransport constructor. + If set to None, a transport is chosen automatically. + client_options (Optional[Union[google.api_core.client_options.ClientOptions, dict]]): + Custom options for the client. + + 1. The ``api_endpoint`` property can be used to override the + default endpoint provided by the client when ``transport`` is + not explicitly provided. Only if this property is not set and + ``transport`` was not explicitly provided, the endpoint is + determined by the GOOGLE_API_USE_MTLS_ENDPOINT environment + variable, which have one of the following values: + "always" (always use the default mTLS endpoint), "never" (always + use the default regular endpoint) and "auto" (auto-switch to the + default mTLS endpoint if client certificate is present; this is + the default value). + + 2. If the GOOGLE_API_USE_CLIENT_CERTIFICATE environment variable + is "true", then the ``client_cert_source`` property can be used + to provide a client certificate for mTLS transport. If + not provided, the default SSL client certificate will be used if + present. If GOOGLE_API_USE_CLIENT_CERTIFICATE is "false" or not + set, no client certificate will be used. + + 3. The ``universe_domain`` property can be used to override the + default "googleapis.com" universe. Note that the ``api_endpoint`` + property still takes precedence; and ``universe_domain`` is + currently not supported for mTLS. + + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + """ + self._client_options = client_options + if isinstance(self._client_options, dict): + self._client_options = client_options_lib.from_dict(self._client_options) + if self._client_options is None: + self._client_options = client_options_lib.ClientOptions() + self._client_options = cast( + client_options_lib.ClientOptions, self._client_options + ) + + universe_domain_opt = getattr(self._client_options, "universe_domain", None) + + ( + self._use_client_cert, + self._use_mtls_endpoint, + self._universe_domain_env, + ) = TextServiceClient._read_environment_variables() + self._client_cert_source = TextServiceClient._get_client_cert_source( + self._client_options.client_cert_source, self._use_client_cert + ) + self._universe_domain = TextServiceClient._get_universe_domain( + universe_domain_opt, self._universe_domain_env + ) + self._api_endpoint = None # updated below, depending on `transport` + + # Initialize the universe domain validation. + self._is_universe_domain_valid = False + + if CLIENT_LOGGING_SUPPORTED: # pragma: NO COVER + # Setup logging. + client_logging.initialize_logging() + + api_key_value = getattr(self._client_options, "api_key", None) + if api_key_value and credentials: + raise ValueError( + "client_options.api_key and credentials are mutually exclusive" + ) + + # Save or instantiate the transport. + # Ordinarily, we provide the transport, but allowing a custom transport + # instance provides an extensibility point for unusual situations. + transport_provided = isinstance(transport, TextServiceTransport) + if transport_provided: + # transport is a TextServiceTransport instance. + if credentials or self._client_options.credentials_file or api_key_value: + raise ValueError( + "When providing a transport instance, " + "provide its credentials directly." + ) + if self._client_options.scopes: + raise ValueError( + "When providing a transport instance, provide its scopes " + "directly." + ) + self._transport = cast(TextServiceTransport, transport) + self._api_endpoint = self._transport.host + + self._api_endpoint = self._api_endpoint or TextServiceClient._get_api_endpoint( + self._client_options.api_endpoint, + self._client_cert_source, + self._universe_domain, + self._use_mtls_endpoint, + ) + + if not transport_provided: + import google.auth._default # type: ignore + + if api_key_value and hasattr( + google.auth._default, "get_api_key_credentials" + ): + credentials = google.auth._default.get_api_key_credentials( + api_key_value + ) + + transport_init: Union[ + Type[TextServiceTransport], Callable[..., TextServiceTransport] + ] = ( + TextServiceClient.get_transport_class(transport) + if isinstance(transport, str) or transport is None + else cast(Callable[..., TextServiceTransport], transport) + ) + # initialize with the provided callable or the passed in class + self._transport = transport_init( + credentials=credentials, + credentials_file=self._client_options.credentials_file, + host=self._api_endpoint, + scopes=self._client_options.scopes, + client_cert_source_for_mtls=self._client_cert_source, + quota_project_id=self._client_options.quota_project_id, + client_info=client_info, + always_use_jwt_access=True, + api_audience=self._client_options.api_audience, + ) + + if "async" not in str(self._transport): + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ): # pragma: NO COVER + _LOGGER.debug( + "Created client `google.ai.generativelanguage_v1alpha.TextServiceClient`.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "universeDomain": getattr( + self._transport._credentials, "universe_domain", "" + ), + "credentialsType": f"{type(self._transport._credentials).__module__}.{type(self._transport._credentials).__qualname__}", + "credentialsInfo": getattr( + self.transport._credentials, "get_cred_info", lambda: None + )(), + } + if hasattr(self._transport, "_credentials") + else { + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "credentialsType": None, + }, + ) + + def generate_text( + self, + request: Optional[Union[text_service.GenerateTextRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + temperature: Optional[float] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.GenerateTextResponse: + r"""Generates a response from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_generate_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_text(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.GenerateTextRequest, dict]): + The request object. Request to generate a text completion + response from the model. + model (str): + Required. The name of the ``Model`` or ``TunedModel`` to + use for generating the completion. Examples: + models/text-bison-001 + tunedModels/sentence-translator-u3b7m + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1alpha.types.TextPrompt): + Required. The free-form input text + given to the model as a prompt. + Given a prompt, the model will generate + a TextCompletion response it predicts as + the completion of the input text. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + temperature (float): + Optional. Controls the randomness of the output. Note: + The default value varies by model, see the + ``Model.temperature`` attribute of the ``Model`` + returned the ``getModel`` function. + + Values can range from [0.0,1.0], inclusive. A value + closer to 1.0 will produce responses that are more + varied and creative, while a value closer to 0.0 will + typically result in more straightforward responses from + the model. + + This corresponds to the ``temperature`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + candidate_count (int): + Optional. Number of generated responses to return. + + This value must be between [1, 8], inclusive. If unset, + this will default to 1. + + This corresponds to the ``candidate_count`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + max_output_tokens (int): + Optional. The maximum number of tokens to include in a + candidate. + + If unset, this will default to output_token_limit + specified in the ``Model`` specification. + + This corresponds to the ``max_output_tokens`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_p (float): + Optional. The maximum cumulative probability of tokens + to consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Tokens are sorted based on their assigned probabilities + so that only the most likely tokens are considered. + Top-k sampling directly limits the maximum number of + tokens to consider, while Nucleus sampling limits number + of tokens based on the cumulative probability. + + Note: The default value varies by model, see the + ``Model.top_p`` attribute of the ``Model`` returned the + ``getModel`` function. + + This corresponds to the ``top_p`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most + probable tokens. Defaults to 40. + + Note: The default value varies by model, see the + ``Model.top_k`` attribute of the ``Model`` returned the + ``getModel`` function. + + This corresponds to the ``top_k`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.GenerateTextResponse: + The response from the model, + including candidate completions. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any( + [ + model, + prompt, + temperature, + candidate_count, + max_output_tokens, + top_p, + top_k, + ] + ) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.GenerateTextRequest): + request = text_service.GenerateTextRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + if temperature is not None: + request.temperature = temperature + if candidate_count is not None: + request.candidate_count = candidate_count + if max_output_tokens is not None: + request.max_output_tokens = max_output_tokens + if top_p is not None: + request.top_p = top_p + if top_k is not None: + request.top_k = top_k + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.generate_text] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def embed_text( + self, + request: Optional[Union[text_service.EmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + text: Optional[str] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.EmbedTextResponse: + r"""Generates an embedding from the model given an input + message. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedTextRequest( + model="model_value", + ) + + # Make the request + response = client.embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.EmbedTextRequest, dict]): + The request object. Request to get a text embedding from + the model. + model (str): + Required. The model name to use with + the format model=models/{model}. + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + text (str): + Optional. The free-form input text + that the model will turn into an + embedding. + + This corresponds to the ``text`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.EmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, text]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.EmbedTextRequest): + request = text_service.EmbedTextRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if text is not None: + request.text = text + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.embed_text] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def batch_embed_text( + self, + request: Optional[Union[text_service.BatchEmbedTextRequest, dict]] = None, + *, + model: Optional[str] = None, + texts: Optional[MutableSequence[str]] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.BatchEmbedTextResponse: + r"""Generates multiple embeddings from the model given + input text in a synchronous call. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.BatchEmbedTextRequest( + model="model_value", + ) + + # Make the request + response = client.batch_embed_text(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.BatchEmbedTextRequest, dict]): + The request object. Batch request to get a text embedding + from the model. + model (str): + Required. The name of the ``Model`` to use for + generating the embedding. Examples: + models/embedding-gecko-001 + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + texts (MutableSequence[str]): + Optional. The free-form input texts + that the model will turn into an + embedding. The current limit is 100 + texts, over which an error will be + thrown. + + This corresponds to the ``texts`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.BatchEmbedTextResponse: + The response to a EmbedTextRequest. + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, texts]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.BatchEmbedTextRequest): + request = text_service.BatchEmbedTextRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if texts is not None: + request.texts = texts + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.batch_embed_text] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def count_text_tokens( + self, + request: Optional[Union[text_service.CountTextTokensRequest, dict]] = None, + *, + model: Optional[str] = None, + prompt: Optional[text_service.TextPrompt] = None, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.CountTextTokensResponse: + r"""Runs a model's tokenizer on a text and returns the + token count. + + .. code-block:: python + + # This snippet has been automatically generated and should be regarded as a + # code template only. + # It will require modifications to work: + # - It may require correct/in-range values for request initialization. + # - It may require specifying regional endpoints when creating the service + # client as shown in: + # https://googleapis.dev/python/google-api-core/latest/client_options.html + from google.ai import generativelanguage_v1alpha + + def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_text_tokens(request=request) + + # Handle the response + print(response) + + Args: + request (Union[google.ai.generativelanguage_v1alpha.types.CountTextTokensRequest, dict]): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + model (str): + Required. The model's resource name. This serves as an + ID for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + + This corresponds to the ``model`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + prompt (google.ai.generativelanguage_v1alpha.types.TextPrompt): + Required. The free-form input text + given to the model as a prompt. + + This corresponds to the ``prompt`` field + on the ``request`` instance; if ``request`` is provided, this + should not be set. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + google.ai.generativelanguage_v1alpha.types.CountTextTokensResponse: + A response from CountTextTokens. + + It returns the model's token_count for the prompt. + + """ + # Create or coerce a protobuf request object. + # - Quick check: If we got a request object, we should *not* have + # gotten any keyword arguments that map to the request. + has_flattened_params = any([model, prompt]) + if request is not None and has_flattened_params: + raise ValueError( + "If the `request` argument is set, then none of " + "the individual field arguments should be set." + ) + + # - Use the request object if provided (there's no risk of modifying the input as + # there are no flattened fields), or create one. + if not isinstance(request, text_service.CountTextTokensRequest): + request = text_service.CountTextTokensRequest(request) + # If we have keyword arguments corresponding to fields on the + # request, apply these. + if model is not None: + request.model = model + if prompt is not None: + request.prompt = prompt + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.count_text_tokens] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("model", request.model),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def __enter__(self) -> "TextServiceClient": + return self + + def __exit__(self, type, value, traceback): + """Releases underlying transport's resources. + + .. warning:: + ONLY use as a context manager if the transport is NOT shared + with other clients! Exiting the with block will CLOSE the transport + and may cause errors in other clients! + """ + self.transport.close() + + def list_operations( + self, + request: Optional[operations_pb2.ListOperationsRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Lists operations that match the specified filter in the request. + + Args: + request (:class:`~.operations_pb2.ListOperationsRequest`): + The request object. Request message for + `ListOperations` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.ListOperationsResponse: + Response message for ``ListOperations`` method. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.ListOperationsRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.list_operations] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + def get_operation( + self, + request: Optional[operations_pb2.GetOperationRequest] = None, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Union[float, object] = gapic_v1.method.DEFAULT, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Gets the latest state of a long-running operation. + + Args: + request (:class:`~.operations_pb2.GetOperationRequest`): + The request object. Request message for + `GetOperation` method. + retry (google.api_core.retry.Retry): Designation of what errors, + if any, should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + Returns: + ~.operations_pb2.Operation: + An ``Operation`` object. + """ + # Create or coerce a protobuf request object. + # The request isn't a proto-plus wrapped type, + # so it must be constructed via keyword expansion. + if isinstance(request, dict): + request = operations_pb2.GetOperationRequest(**request) + + # Wrap the RPC method; this adds retry and timeout information, + # and friendly error handling. + rpc = self._transport._wrapped_methods[self._transport.get_operation] + + # Certain fields should be provided within the metadata header; + # add these here. + metadata = tuple(metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("name", request.name),)), + ) + + # Validate the universe domain. + self._validate_universe_domain() + + # Send the request. + response = rpc( + request, + retry=retry, + timeout=timeout, + metadata=metadata, + ) + + # Done; return the response. + return response + + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +__all__ = ("TextServiceClient",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/README.rst b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/README.rst new file mode 100644 index 000000000000..b603a3ea7af6 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/README.rst @@ -0,0 +1,9 @@ + +transport inheritance structure +_______________________________ + +`TextServiceTransport` is the ABC for all transports. +- public child `TextServiceGrpcTransport` for sync gRPC transport (defined in `grpc.py`). +- public child `TextServiceGrpcAsyncIOTransport` for async gRPC transport (defined in `grpc_asyncio.py`). +- private child `_BaseTextServiceRestTransport` for base REST transport with inner classes `_BaseMETHOD` (defined in `rest_base.py`). +- public child `TextServiceRestTransport` for sync REST transport with inner classes `METHOD` derived from the parent's corresponding `_BaseMETHOD` classes (defined in `rest.py`). diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/__init__.py new file mode 100644 index 000000000000..8670693e2f18 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/__init__.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from collections import OrderedDict +from typing import Dict, Type + +from .base import TextServiceTransport +from .grpc import TextServiceGrpcTransport +from .grpc_asyncio import TextServiceGrpcAsyncIOTransport +from .rest import TextServiceRestInterceptor, TextServiceRestTransport + +# Compile a registry of transports. +_transport_registry = OrderedDict() # type: Dict[str, Type[TextServiceTransport]] +_transport_registry["grpc"] = TextServiceGrpcTransport +_transport_registry["grpc_asyncio"] = TextServiceGrpcAsyncIOTransport +_transport_registry["rest"] = TextServiceRestTransport + +__all__ = ( + "TextServiceTransport", + "TextServiceGrpcTransport", + "TextServiceGrpcAsyncIOTransport", + "TextServiceRestTransport", + "TextServiceRestInterceptor", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/base.py new file mode 100644 index 000000000000..25b16f03ca97 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/base.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +from typing import Awaitable, Callable, Dict, Optional, Sequence, Union + +import google.api_core +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1 +from google.api_core import retry as retries +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account # type: ignore + +from google.ai.generativelanguage_v1alpha import gapic_version as package_version +from google.ai.generativelanguage_v1alpha.types import text_service + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=package_version.__version__ +) + + +class TextServiceTransport(abc.ABC): + """Abstract transport class for TextService.""" + + AUTH_SCOPES = () + + DEFAULT_HOST: str = "generativelanguage.googleapis.com" + + def __init__( + self, + *, + host: str = DEFAULT_HOST, + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + **kwargs, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A list of scopes. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + """ + + scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES} + + # Save the scopes. + self._scopes = scopes + if not hasattr(self, "_ignore_credentials"): + self._ignore_credentials: bool = False + + # If no credentials are provided, then determine the appropriate + # defaults. + if credentials and credentials_file: + raise core_exceptions.DuplicateCredentialArgs( + "'credentials_file' and 'credentials' are mutually exclusive" + ) + + if credentials_file is not None: + credentials, _ = google.auth.load_credentials_from_file( + credentials_file, **scopes_kwargs, quota_project_id=quota_project_id + ) + elif credentials is None and not self._ignore_credentials: + credentials, _ = google.auth.default( + **scopes_kwargs, quota_project_id=quota_project_id + ) + # Don't apply audience if the credentials file passed from user. + if hasattr(credentials, "with_gdch_audience"): + credentials = credentials.with_gdch_audience( + api_audience if api_audience else host + ) + + # If the credentials are service account credentials, then always try to use self signed JWT. + if ( + always_use_jwt_access + and isinstance(credentials, service_account.Credentials) + and hasattr(service_account.Credentials, "with_always_use_jwt_access") + ): + credentials = credentials.with_always_use_jwt_access(True) + + # Save the credentials. + self._credentials = credentials + + # Save the hostname. Default to port 443 (HTTPS) if none is specified. + if ":" not in host: + host += ":443" + self._host = host + + @property + def host(self): + return self._host + + def _prep_wrapped_messages(self, client_info): + # Precompute the wrapped methods. + self._wrapped_methods = { + self.generate_text: gapic_v1.method.wrap_method( + self.generate_text, + default_timeout=None, + client_info=client_info, + ), + self.embed_text: gapic_v1.method.wrap_method( + self.embed_text, + default_timeout=None, + client_info=client_info, + ), + self.batch_embed_text: gapic_v1.method.wrap_method( + self.batch_embed_text, + default_timeout=None, + client_info=client_info, + ), + self.count_text_tokens: gapic_v1.method.wrap_method( + self.count_text_tokens, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: gapic_v1.method.wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: gapic_v1.method.wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def close(self): + """Closes resources associated with the transport. + + .. warning:: + Only call this method if the transport is NOT shared + with other clients - this may cause errors in other clients! + """ + raise NotImplementedError() + + @property + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], + Union[ + text_service.GenerateTextResponse, + Awaitable[text_service.GenerateTextResponse], + ], + ]: + raise NotImplementedError() + + @property + def embed_text( + self, + ) -> Callable[ + [text_service.EmbedTextRequest], + Union[ + text_service.EmbedTextResponse, Awaitable[text_service.EmbedTextResponse] + ], + ]: + raise NotImplementedError() + + @property + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], + Union[ + text_service.BatchEmbedTextResponse, + Awaitable[text_service.BatchEmbedTextResponse], + ], + ]: + raise NotImplementedError() + + @property + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], + Union[ + text_service.CountTextTokensResponse, + Awaitable[text_service.CountTextTokensResponse], + ], + ]: + raise NotImplementedError() + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], + Union[ + operations_pb2.ListOperationsResponse, + Awaitable[operations_pb2.ListOperationsResponse], + ], + ]: + raise NotImplementedError() + + @property + def get_operation( + self, + ) -> Callable[ + [operations_pb2.GetOperationRequest], + Union[operations_pb2.Operation, Awaitable[operations_pb2.Operation]], + ]: + raise NotImplementedError() + + @property + def kind(self) -> str: + raise NotImplementedError() + + +__all__ = ("TextServiceTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc.py new file mode 100644 index 000000000000..4ea6970fa349 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc.py @@ -0,0 +1,485 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json +import logging as std_logging +import pickle +from typing import Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import gapic_v1, grpc_helpers +import google.auth # type: ignore +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import text_service + +from .base import DEFAULT_CLIENT_INFO, TextServiceTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientInterceptor(grpc.UnaryUnaryClientInterceptor): # pragma: NO COVER + def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": client_call_details.method, + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + + response = continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = response.result() + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response for {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": client_call_details.method, + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class TextServiceGrpcTransport(TextServiceTransport): + """gRPC backend transport for TextService. + + API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _stubs: Dict[str, Callable] + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if a ``channel`` instance is provided. + channel (Optional[Union[grpc.Channel, Callable[..., grpc.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTLSChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, grpc.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientInterceptor() + self._logged_channel = grpc.intercept_channel( + self._grpc_channel, self._interceptor + ) + + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> grpc.Channel: + """Create and return a gRPC channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is mutually exclusive with credentials. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + grpc.Channel: A gRPC channel object. + + Raises: + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + + return grpc_helpers.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + @property + def grpc_channel(self) -> grpc.Channel: + """Return the channel designed to connect to this service.""" + return self._grpc_channel + + @property + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], text_service.GenerateTextResponse + ]: + r"""Return a callable for the generate text method over gRPC. + + Generates a response from the model given an input + message. + + Returns: + Callable[[~.GenerateTextRequest], + ~.GenerateTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_text" not in self._stubs: + self._stubs["generate_text"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/GenerateText", + request_serializer=text_service.GenerateTextRequest.serialize, + response_deserializer=text_service.GenerateTextResponse.deserialize, + ) + return self._stubs["generate_text"] + + @property + def embed_text( + self, + ) -> Callable[[text_service.EmbedTextRequest], text_service.EmbedTextResponse]: + r"""Return a callable for the embed text method over gRPC. + + Generates an embedding from the model given an input + message. + + Returns: + Callable[[~.EmbedTextRequest], + ~.EmbedTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "embed_text" not in self._stubs: + self._stubs["embed_text"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/EmbedText", + request_serializer=text_service.EmbedTextRequest.serialize, + response_deserializer=text_service.EmbedTextResponse.deserialize, + ) + return self._stubs["embed_text"] + + @property + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], text_service.BatchEmbedTextResponse + ]: + r"""Return a callable for the batch embed text method over gRPC. + + Generates multiple embeddings from the model given + input text in a synchronous call. + + Returns: + Callable[[~.BatchEmbedTextRequest], + ~.BatchEmbedTextResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_embed_text" not in self._stubs: + self._stubs["batch_embed_text"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/BatchEmbedText", + request_serializer=text_service.BatchEmbedTextRequest.serialize, + response_deserializer=text_service.BatchEmbedTextResponse.deserialize, + ) + return self._stubs["batch_embed_text"] + + @property + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], text_service.CountTextTokensResponse + ]: + r"""Return a callable for the count text tokens method over gRPC. + + Runs a model's tokenizer on a text and returns the + token count. + + Returns: + Callable[[~.CountTextTokensRequest], + ~.CountTextTokensResponse]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "count_text_tokens" not in self._stubs: + self._stubs["count_text_tokens"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/CountTextTokens", + request_serializer=text_service.CountTextTokensRequest.serialize, + response_deserializer=text_service.CountTextTokensResponse.deserialize, + ) + return self._stubs["count_text_tokens"] + + def close(self): + self._logged_channel.close() + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + @property + def kind(self) -> str: + return "grpc" + + +__all__ = ("TextServiceGrpcTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc_asyncio.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc_asyncio.py new file mode 100644 index 000000000000..24519afd29bb --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/grpc_asyncio.py @@ -0,0 +1,536 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import logging as std_logging +import pickle +from typing import Awaitable, Callable, Dict, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, grpc_helpers_async +from google.api_core import retry_async as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.grpc import SslCredentials # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf.json_format import MessageToJson +import google.protobuf.message +import grpc # type: ignore +from grpc.experimental import aio # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import text_service + +from .base import DEFAULT_CLIENT_INFO, TextServiceTransport +from .grpc import TextServiceGrpcTransport + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = std_logging.getLogger(__name__) + + +class _LoggingClientAIOInterceptor( + grpc.aio.UnaryUnaryClientInterceptor +): # pragma: NO COVER + async def intercept_unary_unary(self, continuation, client_call_details, request): + logging_enabled = CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + std_logging.DEBUG + ) + if logging_enabled: # pragma: NO COVER + request_metadata = client_call_details.metadata + if isinstance(request, proto.Message): + request_payload = type(request).to_json(request) + elif isinstance(request, google.protobuf.message.Message): + request_payload = MessageToJson(request) + else: + request_payload = f"{type(request).__name__}: {pickle.dumps(request)}" + + request_metadata = { + key: value.decode("utf-8") if isinstance(value, bytes) else value + for key, value in request_metadata + } + grpc_request = { + "payload": request_payload, + "requestMethod": "grpc", + "metadata": dict(request_metadata), + } + _LOGGER.debug( + f"Sending request for {client_call_details.method}", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": str(client_call_details.method), + "request": grpc_request, + "metadata": grpc_request["metadata"], + }, + ) + response = await continuation(client_call_details, request) + if logging_enabled: # pragma: NO COVER + response_metadata = await response.trailing_metadata() + # Convert gRPC metadata `` to list of tuples + metadata = ( + dict([(k, str(v)) for k, v in response_metadata]) + if response_metadata + else None + ) + result = await response + if isinstance(result, proto.Message): + response_payload = type(result).to_json(result) + elif isinstance(result, google.protobuf.message.Message): + response_payload = MessageToJson(result) + else: + response_payload = f"{type(result).__name__}: {pickle.dumps(result)}" + grpc_response = { + "payload": response_payload, + "metadata": metadata, + "status": "OK", + } + _LOGGER.debug( + f"Received response to rpc {client_call_details.method}.", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": str(client_call_details.method), + "response": grpc_response, + "metadata": grpc_response["metadata"], + }, + ) + return response + + +class TextServiceGrpcAsyncIOTransport(TextServiceTransport): + """gRPC AsyncIO backend transport for TextService. + + API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends protocol buffers over the wire using gRPC (which is built on + top of HTTP/2); the ``grpcio`` package must be installed. + """ + + _grpc_channel: aio.Channel + _stubs: Dict[str, Callable] = {} + + @classmethod + def create_channel( + cls, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + quota_project_id: Optional[str] = None, + **kwargs, + ) -> aio.Channel: + """Create and return a gRPC AsyncIO channel object. + Args: + host (Optional[str]): The host for the channel to use. + credentials (Optional[~.Credentials]): The + authorization credentials to attach to requests. These + credentials identify this application to the service. If + none are specified, the client will attempt to ascertain + the credentials from the environment. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + kwargs (Optional[dict]): Keyword arguments, which are passed to the + channel creation. + Returns: + aio.Channel: A gRPC AsyncIO channel object. + """ + + return grpc_helpers_async.create_channel( + host, + credentials=credentials, + credentials_file=credentials_file, + quota_project_id=quota_project_id, + default_scopes=cls.AUTH_SCOPES, + scopes=scopes, + default_host=cls.DEFAULT_HOST, + **kwargs, + ) + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + channel: Optional[Union[aio.Channel, Callable[..., aio.Channel]]] = None, + api_mtls_endpoint: Optional[str] = None, + client_cert_source: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + ssl_channel_credentials: Optional[grpc.ChannelCredentials] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + This argument is ignored if a ``channel`` instance is provided. + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if a ``channel`` instance is provided. + scopes (Optional[Sequence[str]]): A optional list of scopes needed for this + service. These are only used when credentials are not specified and + are passed to :func:`google.auth.default`. + channel (Optional[Union[aio.Channel, Callable[..., aio.Channel]]]): + A ``Channel`` instance through which to make calls, or a Callable + that constructs and returns one. If set to None, ``self.create_channel`` + is used to create the channel. If a Callable is given, it will be called + with the same arguments as used in ``self.create_channel``. + api_mtls_endpoint (Optional[str]): Deprecated. The mutual TLS endpoint. + If provided, it overrides the ``host`` argument and tries to create + a mutual TLS channel with client SSL credentials from + ``client_cert_source`` or application default SSL credentials. + client_cert_source (Optional[Callable[[], Tuple[bytes, bytes]]]): + Deprecated. A callback to provide client SSL certificate bytes and + private key bytes, both in PEM format. It is ignored if + ``api_mtls_endpoint`` is None. + ssl_channel_credentials (grpc.ChannelCredentials): SSL credentials + for the grpc channel. It is ignored if a ``channel`` instance is provided. + client_cert_source_for_mtls (Optional[Callable[[], Tuple[bytes, bytes]]]): + A callback to provide client certificate bytes and private key bytes, + both in PEM format. It is used to configure a mutual TLS channel. It is + ignored if a ``channel`` instance or ``ssl_channel_credentials`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you're developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + + Raises: + google.auth.exceptions.MutualTlsChannelError: If mutual TLS transport + creation failed for any reason. + google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials`` + and ``credentials_file`` are passed. + """ + self._grpc_channel = None + self._ssl_channel_credentials = ssl_channel_credentials + self._stubs: Dict[str, Callable] = {} + + if api_mtls_endpoint: + warnings.warn("api_mtls_endpoint is deprecated", DeprecationWarning) + if client_cert_source: + warnings.warn("client_cert_source is deprecated", DeprecationWarning) + + if isinstance(channel, aio.Channel): + # Ignore credentials if a channel was passed. + credentials = None + self._ignore_credentials = True + # If a channel was explicitly provided, set it. + self._grpc_channel = channel + self._ssl_channel_credentials = None + else: + if api_mtls_endpoint: + host = api_mtls_endpoint + + # Create SSL credentials with client_cert_source or application + # default SSL credentials. + if client_cert_source: + cert, key = client_cert_source() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + else: + self._ssl_channel_credentials = SslCredentials().ssl_credentials + + else: + if client_cert_source_for_mtls and not ssl_channel_credentials: + cert, key = client_cert_source_for_mtls() + self._ssl_channel_credentials = grpc.ssl_channel_credentials( + certificate_chain=cert, private_key=key + ) + + # The base transport sets the host, credentials and scopes + super().__init__( + host=host, + credentials=credentials, + credentials_file=credentials_file, + scopes=scopes, + quota_project_id=quota_project_id, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + if not self._grpc_channel: + # initialize with the provided callable or the default channel + channel_init = channel or type(self).create_channel + self._grpc_channel = channel_init( + self._host, + # use the credentials which are saved + credentials=self._credentials, + # Set ``credentials_file`` to ``None`` here as + # the credentials that we saved earlier should be used. + credentials_file=None, + scopes=self._scopes, + ssl_credentials=self._ssl_channel_credentials, + quota_project_id=quota_project_id, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + self._interceptor = _LoggingClientAIOInterceptor() + self._grpc_channel._unary_unary_interceptors.append(self._interceptor) + self._logged_channel = self._grpc_channel + self._wrap_with_kind = ( + "kind" in inspect.signature(gapic_v1.method_async.wrap_method).parameters + ) + # Wrap messages. This must be done after self._logged_channel exists + self._prep_wrapped_messages(client_info) + + @property + def grpc_channel(self) -> aio.Channel: + """Create the channel designed to connect to this service. + + This property caches on the instance; repeated calls return + the same channel. + """ + # Return the channel from cache. + return self._grpc_channel + + @property + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], Awaitable[text_service.GenerateTextResponse] + ]: + r"""Return a callable for the generate text method over gRPC. + + Generates a response from the model given an input + message. + + Returns: + Callable[[~.GenerateTextRequest], + Awaitable[~.GenerateTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "generate_text" not in self._stubs: + self._stubs["generate_text"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/GenerateText", + request_serializer=text_service.GenerateTextRequest.serialize, + response_deserializer=text_service.GenerateTextResponse.deserialize, + ) + return self._stubs["generate_text"] + + @property + def embed_text( + self, + ) -> Callable[ + [text_service.EmbedTextRequest], Awaitable[text_service.EmbedTextResponse] + ]: + r"""Return a callable for the embed text method over gRPC. + + Generates an embedding from the model given an input + message. + + Returns: + Callable[[~.EmbedTextRequest], + Awaitable[~.EmbedTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "embed_text" not in self._stubs: + self._stubs["embed_text"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/EmbedText", + request_serializer=text_service.EmbedTextRequest.serialize, + response_deserializer=text_service.EmbedTextResponse.deserialize, + ) + return self._stubs["embed_text"] + + @property + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], + Awaitable[text_service.BatchEmbedTextResponse], + ]: + r"""Return a callable for the batch embed text method over gRPC. + + Generates multiple embeddings from the model given + input text in a synchronous call. + + Returns: + Callable[[~.BatchEmbedTextRequest], + Awaitable[~.BatchEmbedTextResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "batch_embed_text" not in self._stubs: + self._stubs["batch_embed_text"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/BatchEmbedText", + request_serializer=text_service.BatchEmbedTextRequest.serialize, + response_deserializer=text_service.BatchEmbedTextResponse.deserialize, + ) + return self._stubs["batch_embed_text"] + + @property + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], + Awaitable[text_service.CountTextTokensResponse], + ]: + r"""Return a callable for the count text tokens method over gRPC. + + Runs a model's tokenizer on a text and returns the + token count. + + Returns: + Callable[[~.CountTextTokensRequest], + Awaitable[~.CountTextTokensResponse]]: + A function that, when called, will call the underlying RPC + on the server. + """ + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "count_text_tokens" not in self._stubs: + self._stubs["count_text_tokens"] = self._logged_channel.unary_unary( + "/google.ai.generativelanguage.v1alpha.TextService/CountTextTokens", + request_serializer=text_service.CountTextTokensRequest.serialize, + response_deserializer=text_service.CountTextTokensResponse.deserialize, + ) + return self._stubs["count_text_tokens"] + + def _prep_wrapped_messages(self, client_info): + """Precompute the wrapped methods, overriding the base class method to use async wrappers.""" + self._wrapped_methods = { + self.generate_text: self._wrap_method( + self.generate_text, + default_timeout=None, + client_info=client_info, + ), + self.embed_text: self._wrap_method( + self.embed_text, + default_timeout=None, + client_info=client_info, + ), + self.batch_embed_text: self._wrap_method( + self.batch_embed_text, + default_timeout=None, + client_info=client_info, + ), + self.count_text_tokens: self._wrap_method( + self.count_text_tokens, + default_timeout=None, + client_info=client_info, + ), + self.get_operation: self._wrap_method( + self.get_operation, + default_timeout=None, + client_info=client_info, + ), + self.list_operations: self._wrap_method( + self.list_operations, + default_timeout=None, + client_info=client_info, + ), + } + + def _wrap_method(self, func, *args, **kwargs): + if self._wrap_with_kind: # pragma: NO COVER + kwargs["kind"] = self.kind + return gapic_v1.method_async.wrap_method(func, *args, **kwargs) + + def close(self): + return self._logged_channel.close() + + @property + def kind(self) -> str: + return "grpc_asyncio" + + @property + def get_operation( + self, + ) -> Callable[[operations_pb2.GetOperationRequest], operations_pb2.Operation]: + r"""Return a callable for the get_operation method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "get_operation" not in self._stubs: + self._stubs["get_operation"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/GetOperation", + request_serializer=operations_pb2.GetOperationRequest.SerializeToString, + response_deserializer=operations_pb2.Operation.FromString, + ) + return self._stubs["get_operation"] + + @property + def list_operations( + self, + ) -> Callable[ + [operations_pb2.ListOperationsRequest], operations_pb2.ListOperationsResponse + ]: + r"""Return a callable for the list_operations method over gRPC.""" + # Generate a "stub function" on-the-fly which will actually make + # the request. + # gRPC handles serialization and deserialization, so we just need + # to pass in the functions for each. + if "list_operations" not in self._stubs: + self._stubs["list_operations"] = self._logged_channel.unary_unary( + "/google.longrunning.Operations/ListOperations", + request_serializer=operations_pb2.ListOperationsRequest.SerializeToString, + response_deserializer=operations_pb2.ListOperationsResponse.FromString, + ) + return self._stubs["list_operations"] + + +__all__ = ("TextServiceGrpcAsyncIOTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest.py new file mode 100644 index 000000000000..87f2c2bc093c --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest.py @@ -0,0 +1,1293 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import json # type: ignore +import logging +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +import warnings + +from google.api_core import exceptions as core_exceptions +from google.api_core import gapic_v1, rest_helpers, rest_streaming +from google.api_core import retry as retries +from google.auth import credentials as ga_credentials # type: ignore +from google.auth.transport.requests import AuthorizedSession # type: ignore +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format +from requests import __version__ as requests_version + +from google.ai.generativelanguage_v1alpha.types import text_service + +from .base import DEFAULT_CLIENT_INFO as BASE_DEFAULT_CLIENT_INFO +from .rest_base import _BaseTextServiceRestTransport + +try: + OptionalRetry = Union[retries.Retry, gapic_v1.method._MethodDefault, None] +except AttributeError: # pragma: NO COVER + OptionalRetry = Union[retries.Retry, object, None] # type: ignore + +try: + from google.api_core import client_logging # type: ignore + + CLIENT_LOGGING_SUPPORTED = True # pragma: NO COVER +except ImportError: # pragma: NO COVER + CLIENT_LOGGING_SUPPORTED = False + +_LOGGER = logging.getLogger(__name__) + +DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( + gapic_version=BASE_DEFAULT_CLIENT_INFO.gapic_version, + grpc_version=None, + rest_version=f"requests@{requests_version}", +) + + +class TextServiceRestInterceptor: + """Interceptor for TextService. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the TextServiceRestTransport. + + .. code-block:: python + class MyCustomTextServiceInterceptor(TextServiceRestInterceptor): + def pre_batch_embed_text(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_batch_embed_text(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_count_text_tokens(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_count_text_tokens(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_embed_text(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_embed_text(self, response): + logging.log(f"Received response: {response}") + return response + + def pre_generate_text(self, request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + def post_generate_text(self, response): + logging.log(f"Received response: {response}") + return response + + transport = TextServiceRestTransport(interceptor=MyCustomTextServiceInterceptor()) + client = TextServiceClient(transport=transport) + + + """ + + def pre_batch_embed_text( + self, + request: text_service.BatchEmbedTextRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + text_service.BatchEmbedTextRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for batch_embed_text + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_batch_embed_text( + self, response: text_service.BatchEmbedTextResponse + ) -> text_service.BatchEmbedTextResponse: + """Post-rpc interceptor for batch_embed_text + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + def pre_count_text_tokens( + self, + request: text_service.CountTextTokensRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + text_service.CountTextTokensRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for count_text_tokens + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_count_text_tokens( + self, response: text_service.CountTextTokensResponse + ) -> text_service.CountTextTokensResponse: + """Post-rpc interceptor for count_text_tokens + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + def pre_embed_text( + self, + request: text_service.EmbedTextRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[text_service.EmbedTextRequest, Sequence[Tuple[str, Union[str, bytes]]]]: + """Pre-rpc interceptor for embed_text + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_embed_text( + self, response: text_service.EmbedTextResponse + ) -> text_service.EmbedTextResponse: + """Post-rpc interceptor for embed_text + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + def pre_generate_text( + self, + request: text_service.GenerateTextRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + text_service.GenerateTextRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for generate_text + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_generate_text( + self, response: text_service.GenerateTextResponse + ) -> text_service.GenerateTextResponse: + """Post-rpc interceptor for generate_text + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + def pre_get_operation( + self, + request: operations_pb2.GetOperationRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.GetOperationRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for get_operation + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_get_operation( + self, response: operations_pb2.Operation + ) -> operations_pb2.Operation: + """Post-rpc interceptor for get_operation + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + def pre_list_operations( + self, + request: operations_pb2.ListOperationsRequest, + metadata: Sequence[Tuple[str, Union[str, bytes]]], + ) -> Tuple[ + operations_pb2.ListOperationsRequest, Sequence[Tuple[str, Union[str, bytes]]] + ]: + """Pre-rpc interceptor for list_operations + + Override in a subclass to manipulate the request or metadata + before they are sent to the TextService server. + """ + return request, metadata + + def post_list_operations( + self, response: operations_pb2.ListOperationsResponse + ) -> operations_pb2.ListOperationsResponse: + """Post-rpc interceptor for list_operations + + Override in a subclass to manipulate the response + after it is returned by the TextService server but before + it is returned to user code. + """ + return response + + +@dataclasses.dataclass +class TextServiceRestStub: + _session: AuthorizedSession + _host: str + _interceptor: TextServiceRestInterceptor + + +class TextServiceRestTransport(_BaseTextServiceRestTransport): + """REST backend synchronous transport for TextService. + + API for using Generative Language Models (GLMs) trained to + generate text. + Also known as Large Language Models (LLM)s, these generate text + given an input prompt from the user. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[ga_credentials.Credentials] = None, + credentials_file: Optional[str] = None, + scopes: Optional[Sequence[str]] = None, + client_cert_source_for_mtls: Optional[Callable[[], Tuple[bytes, bytes]]] = None, + quota_project_id: Optional[str] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + interceptor: Optional[TextServiceRestInterceptor] = None, + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[google.auth.credentials.Credentials]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + + credentials_file (Optional[str]): A file with credentials that can + be loaded with :func:`google.auth.load_credentials_from_file`. + This argument is ignored if ``channel`` is provided. + scopes (Optional(Sequence[str])): A list of scopes. This argument is + ignored if ``channel`` is provided. + client_cert_source_for_mtls (Callable[[], Tuple[bytes, bytes]]): Client + certificate to configure mutual TLS HTTP channel. It is ignored + if ``channel`` is provided. + quota_project_id (Optional[str]): An optional project to use for billing + and quota. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + # TODO(yon-mg): resolve other ctor params i.e. scopes, quota, etc. + # TODO: When custom host (api_endpoint) is set, `scopes` must *also* be set on the + # credentials object + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + url_scheme=url_scheme, + api_audience=api_audience, + ) + self._session = AuthorizedSession( + self._credentials, default_host=self.DEFAULT_HOST + ) + if client_cert_source_for_mtls: + self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or TextServiceRestInterceptor() + self._prep_wrapped_messages(client_info) + + class _BatchEmbedText( + _BaseTextServiceRestTransport._BaseBatchEmbedText, TextServiceRestStub + ): + def __hash__(self): + return hash("TextServiceRestTransport.BatchEmbedText") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: text_service.BatchEmbedTextRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.BatchEmbedTextResponse: + r"""Call the batch embed text method over HTTP. + + Args: + request (~.text_service.BatchEmbedTextRequest): + The request object. Batch request to get a text embedding + from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.text_service.BatchEmbedTextResponse: + The response to a EmbedTextRequest. + """ + + http_options = ( + _BaseTextServiceRestTransport._BaseBatchEmbedText._get_http_options() + ) + + request, metadata = self._interceptor.pre_batch_embed_text( + request, metadata + ) + transcoded_request = _BaseTextServiceRestTransport._BaseBatchEmbedText._get_transcoded_request( + http_options, request + ) + + body = _BaseTextServiceRestTransport._BaseBatchEmbedText._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseTextServiceRestTransport._BaseBatchEmbedText._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.TextServiceClient.BatchEmbedText", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "BatchEmbedText", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = TextServiceRestTransport._BatchEmbedText._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.BatchEmbedTextResponse() + pb_resp = text_service.BatchEmbedTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_batch_embed_text(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = text_service.BatchEmbedTextResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.TextServiceClient.batch_embed_text", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "BatchEmbedText", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _CountTextTokens( + _BaseTextServiceRestTransport._BaseCountTextTokens, TextServiceRestStub + ): + def __hash__(self): + return hash("TextServiceRestTransport.CountTextTokens") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: text_service.CountTextTokensRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.CountTextTokensResponse: + r"""Call the count text tokens method over HTTP. + + Args: + request (~.text_service.CountTextTokensRequest): + The request object. Counts the number of tokens in the ``prompt`` sent to a + model. + + Models may tokenize text differently, so each model may + return a different ``token_count``. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.text_service.CountTextTokensResponse: + A response from ``CountTextTokens``. + + It returns the model's ``token_count`` for the + ``prompt``. + + """ + + http_options = ( + _BaseTextServiceRestTransport._BaseCountTextTokens._get_http_options() + ) + + request, metadata = self._interceptor.pre_count_text_tokens( + request, metadata + ) + transcoded_request = _BaseTextServiceRestTransport._BaseCountTextTokens._get_transcoded_request( + http_options, request + ) + + body = _BaseTextServiceRestTransport._BaseCountTextTokens._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = _BaseTextServiceRestTransport._BaseCountTextTokens._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.TextServiceClient.CountTextTokens", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "CountTextTokens", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = TextServiceRestTransport._CountTextTokens._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.CountTextTokensResponse() + pb_resp = text_service.CountTextTokensResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_count_text_tokens(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = text_service.CountTextTokensResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.TextServiceClient.count_text_tokens", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "CountTextTokens", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _EmbedText(_BaseTextServiceRestTransport._BaseEmbedText, TextServiceRestStub): + def __hash__(self): + return hash("TextServiceRestTransport.EmbedText") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: text_service.EmbedTextRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.EmbedTextResponse: + r"""Call the embed text method over HTTP. + + Args: + request (~.text_service.EmbedTextRequest): + The request object. Request to get a text embedding from + the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.text_service.EmbedTextResponse: + The response to a EmbedTextRequest. + """ + + http_options = ( + _BaseTextServiceRestTransport._BaseEmbedText._get_http_options() + ) + + request, metadata = self._interceptor.pre_embed_text(request, metadata) + transcoded_request = ( + _BaseTextServiceRestTransport._BaseEmbedText._get_transcoded_request( + http_options, request + ) + ) + + body = _BaseTextServiceRestTransport._BaseEmbedText._get_request_body_json( + transcoded_request + ) + + # Jsonify the query params + query_params = ( + _BaseTextServiceRestTransport._BaseEmbedText._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.TextServiceClient.EmbedText", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "EmbedText", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = TextServiceRestTransport._EmbedText._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.EmbedTextResponse() + pb_resp = text_service.EmbedTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_embed_text(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = text_service.EmbedTextResponse.to_json(response) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.TextServiceClient.embed_text", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "EmbedText", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + class _GenerateText( + _BaseTextServiceRestTransport._BaseGenerateText, TextServiceRestStub + ): + def __hash__(self): + return hash("TextServiceRestTransport.GenerateText") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + data=body, + ) + return response + + def __call__( + self, + request: text_service.GenerateTextRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> text_service.GenerateTextResponse: + r"""Call the generate text method over HTTP. + + Args: + request (~.text_service.GenerateTextRequest): + The request object. Request to generate a text completion + response from the model. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + ~.text_service.GenerateTextResponse: + The response from the model, + including candidate completions. + + """ + + http_options = ( + _BaseTextServiceRestTransport._BaseGenerateText._get_http_options() + ) + + request, metadata = self._interceptor.pre_generate_text(request, metadata) + transcoded_request = ( + _BaseTextServiceRestTransport._BaseGenerateText._get_transcoded_request( + http_options, request + ) + ) + + body = ( + _BaseTextServiceRestTransport._BaseGenerateText._get_request_body_json( + transcoded_request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseTextServiceRestTransport._BaseGenerateText._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = type(request).to_json(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.TextServiceClient.GenerateText", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "GenerateText", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = TextServiceRestTransport._GenerateText._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + body, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + # Return the response + resp = text_service.GenerateTextResponse() + pb_resp = text_service.GenerateTextResponse.pb(resp) + + json_format.Parse(response.content, pb_resp, ignore_unknown_fields=True) + + resp = self._interceptor.post_generate_text(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = text_service.GenerateTextResponse.to_json( + response + ) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.TextServiceClient.generate_text", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "GenerateText", + "metadata": http_response["headers"], + "httpResponse": http_response, + }, + ) + return resp + + @property + def batch_embed_text( + self, + ) -> Callable[ + [text_service.BatchEmbedTextRequest], text_service.BatchEmbedTextResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._BatchEmbedText(self._session, self._host, self._interceptor) # type: ignore + + @property + def count_text_tokens( + self, + ) -> Callable[ + [text_service.CountTextTokensRequest], text_service.CountTextTokensResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._CountTextTokens(self._session, self._host, self._interceptor) # type: ignore + + @property + def embed_text( + self, + ) -> Callable[[text_service.EmbedTextRequest], text_service.EmbedTextResponse]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._EmbedText(self._session, self._host, self._interceptor) # type: ignore + + @property + def generate_text( + self, + ) -> Callable[ + [text_service.GenerateTextRequest], text_service.GenerateTextResponse + ]: + # The return type is fine, but mypy isn't sophisticated enough to determine what's going on here. + # In C++ this would require a dynamic_cast + return self._GenerateText(self._session, self._host, self._interceptor) # type: ignore + + @property + def get_operation(self): + return self._GetOperation(self._session, self._host, self._interceptor) # type: ignore + + class _GetOperation( + _BaseTextServiceRestTransport._BaseGetOperation, TextServiceRestStub + ): + def __hash__(self): + return hash("TextServiceRestTransport.GetOperation") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.GetOperationRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.Operation: + r"""Call the get operation method over HTTP. + + Args: + request (operations_pb2.GetOperationRequest): + The request object for GetOperation method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.Operation: Response from GetOperation method. + """ + + http_options = ( + _BaseTextServiceRestTransport._BaseGetOperation._get_http_options() + ) + + request, metadata = self._interceptor.pre_get_operation(request, metadata) + transcoded_request = ( + _BaseTextServiceRestTransport._BaseGetOperation._get_transcoded_request( + http_options, request + ) + ) + + # Jsonify the query params + query_params = ( + _BaseTextServiceRestTransport._BaseGetOperation._get_query_params_json( + transcoded_request + ) + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.TextServiceClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "GetOperation", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = TextServiceRestTransport._GetOperation._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.Operation() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_get_operation(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.TextServiceAsyncClient.GetOperation", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "GetOperation", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def list_operations(self): + return self._ListOperations(self._session, self._host, self._interceptor) # type: ignore + + class _ListOperations( + _BaseTextServiceRestTransport._BaseListOperations, TextServiceRestStub + ): + def __hash__(self): + return hash("TextServiceRestTransport.ListOperations") + + @staticmethod + def _get_response( + host, + metadata, + query_params, + session, + timeout, + transcoded_request, + body=None, + ): + uri = transcoded_request["uri"] + method = transcoded_request["method"] + headers = dict(metadata) + headers["Content-Type"] = "application/json" + response = getattr(session, method)( + "{host}{uri}".format(host=host, uri=uri), + timeout=timeout, + headers=headers, + params=rest_helpers.flatten_query_params(query_params, strict=True), + ) + return response + + def __call__( + self, + request: operations_pb2.ListOperationsRequest, + *, + retry: OptionalRetry = gapic_v1.method.DEFAULT, + timeout: Optional[float] = None, + metadata: Sequence[Tuple[str, Union[str, bytes]]] = (), + ) -> operations_pb2.ListOperationsResponse: + r"""Call the list operations method over HTTP. + + Args: + request (operations_pb2.ListOperationsRequest): + The request object for ListOperations method. + retry (google.api_core.retry.Retry): Designation of what errors, if any, + should be retried. + timeout (float): The timeout for this request. + metadata (Sequence[Tuple[str, Union[str, bytes]]]): Key/value pairs which should be + sent along with the request as metadata. Normally, each value must be of type `str`, + but for metadata keys ending with the suffix `-bin`, the corresponding values must + be of type `bytes`. + + Returns: + operations_pb2.ListOperationsResponse: Response from ListOperations method. + """ + + http_options = ( + _BaseTextServiceRestTransport._BaseListOperations._get_http_options() + ) + + request, metadata = self._interceptor.pre_list_operations(request, metadata) + transcoded_request = _BaseTextServiceRestTransport._BaseListOperations._get_transcoded_request( + http_options, request + ) + + # Jsonify the query params + query_params = _BaseTextServiceRestTransport._BaseListOperations._get_query_params_json( + transcoded_request + ) + + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + request_url = "{host}{uri}".format( + host=self._host, uri=transcoded_request["uri"] + ) + method = transcoded_request["method"] + try: + request_payload = json_format.MessageToJson(request) + except: + request_payload = None + http_request = { + "payload": request_payload, + "requestMethod": method, + "requestUrl": request_url, + "headers": dict(metadata), + } + _LOGGER.debug( + f"Sending request for google.ai.generativelanguage_v1alpha.TextServiceClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "ListOperations", + "httpRequest": http_request, + "metadata": http_request["headers"], + }, + ) + + # Send the request + response = TextServiceRestTransport._ListOperations._get_response( + self._host, + metadata, + query_params, + self._session, + timeout, + transcoded_request, + ) + + # In case of error, raise the appropriate core_exceptions.GoogleAPICallError exception + # subclass. + if response.status_code >= 400: + raise core_exceptions.from_http_response(response) + + content = response.content.decode("utf-8") + resp = operations_pb2.ListOperationsResponse() + resp = json_format.Parse(content, resp) + resp = self._interceptor.post_list_operations(resp) + if CLIENT_LOGGING_SUPPORTED and _LOGGER.isEnabledFor( + logging.DEBUG + ): # pragma: NO COVER + try: + response_payload = json_format.MessageToJson(resp) + except: + response_payload = None + http_response = { + "payload": response_payload, + "headers": dict(response.headers), + "status": response.status_code, + } + _LOGGER.debug( + "Received response for google.ai.generativelanguage_v1alpha.TextServiceAsyncClient.ListOperations", + extra={ + "serviceName": "google.ai.generativelanguage.v1alpha.TextService", + "rpcName": "ListOperations", + "httpResponse": http_response, + "metadata": http_response["headers"], + }, + ) + return resp + + @property + def kind(self) -> str: + return "rest" + + def close(self): + self._session.close() + + +__all__ = ("TextServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest_base.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest_base.py new file mode 100644 index 000000000000..6139a0ec95a1 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/services/text_service/transports/rest_base.py @@ -0,0 +1,387 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import json # type: ignore +import re +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from google.api_core import gapic_v1, path_template +from google.longrunning import operations_pb2 # type: ignore +from google.protobuf import json_format + +from google.ai.generativelanguage_v1alpha.types import text_service + +from .base import DEFAULT_CLIENT_INFO, TextServiceTransport + + +class _BaseTextServiceRestTransport(TextServiceTransport): + """Base REST backend transport for TextService. + + Note: This class is not meant to be used directly. Use its sync and + async sub-classes instead. + + This class defines the same methods as the primary client, so the + primary client can load the underlying transport implementation + and call it. + + It sends JSON representations of protocol buffers over HTTP/1.1 + """ + + def __init__( + self, + *, + host: str = "generativelanguage.googleapis.com", + credentials: Optional[Any] = None, + client_info: gapic_v1.client_info.ClientInfo = DEFAULT_CLIENT_INFO, + always_use_jwt_access: Optional[bool] = False, + url_scheme: str = "https", + api_audience: Optional[str] = None, + ) -> None: + """Instantiate the transport. + Args: + host (Optional[str]): + The hostname to connect to (default: 'generativelanguage.googleapis.com'). + credentials (Optional[Any]): The + authorization credentials to attach to requests. These + credentials identify the application to the service; if none + are specified, the client will attempt to ascertain the + credentials from the environment. + client_info (google.api_core.gapic_v1.client_info.ClientInfo): + The client info used to send a user-agent string along with + API requests. If ``None``, then default info will be used. + Generally, you only need to set this if you are developing + your own client library. + always_use_jwt_access (Optional[bool]): Whether self signed JWT should + be used for service account credentials. + url_scheme: the protocol scheme for the API endpoint. Normally + "https", but for testing or local servers, + "http" can be specified. + """ + # Run the base constructor + maybe_url_match = re.match("^(?Phttp(?:s)?://)?(?P.*)$", host) + if maybe_url_match is None: + raise ValueError( + f"Unexpected hostname structure: {host}" + ) # pragma: NO COVER + + url_match_items = maybe_url_match.groupdict() + + host = f"{url_scheme}://{host}" if not url_match_items["scheme"] else host + + super().__init__( + host=host, + credentials=credentials, + client_info=client_info, + always_use_jwt_access=always_use_jwt_access, + api_audience=api_audience, + ) + + class _BaseBatchEmbedText: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:batchEmbedText", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = text_service.BatchEmbedTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseTextServiceRestTransport._BaseBatchEmbedText._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseCountTextTokens: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:countTextTokens", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = text_service.CountTextTokensRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseTextServiceRestTransport._BaseCountTextTokens._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseEmbedText: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:embedText", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = text_service.EmbedTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseTextServiceRestTransport._BaseEmbedText._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGenerateText: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + __REQUIRED_FIELDS_DEFAULT_VALUES: Dict[str, Any] = {} + + @classmethod + def _get_unset_required_fields(cls, message_dict): + return { + k: v + for k, v in cls.__REQUIRED_FIELDS_DEFAULT_VALUES.items() + if k not in message_dict + } + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "post", + "uri": "/v1alpha/{model=models/*}:generateText", + "body": "*", + }, + { + "method": "post", + "uri": "/v1alpha/{model=tunedModels/*}:generateText", + "body": "*", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + pb_request = text_service.GenerateTextRequest.pb(request) + transcoded_request = path_template.transcode(http_options, pb_request) + return transcoded_request + + @staticmethod + def _get_request_body_json(transcoded_request): + # Jsonify the request body + + body = json_format.MessageToJson( + transcoded_request["body"], use_integers_for_enums=True + ) + return body + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads( + json_format.MessageToJson( + transcoded_request["query_params"], + use_integers_for_enums=True, + ) + ) + query_params.update( + _BaseTextServiceRestTransport._BaseGenerateText._get_unset_required_fields( + query_params + ) + ) + + query_params["$alt"] = "json;enum-encoding=int" + return query_params + + class _BaseGetOperation: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=generatedFiles/*/operations/*}", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*/operations/*}", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + class _BaseListOperations: + def __hash__(self): # pragma: NO COVER + return NotImplementedError("__hash__ must be implemented.") + + @staticmethod + def _get_http_options(): + http_options: List[Dict[str, str]] = [ + { + "method": "get", + "uri": "/v1alpha/{name=tunedModels/*}/operations", + }, + { + "method": "get", + "uri": "/v1alpha/{name=models/*}/operations", + }, + ] + return http_options + + @staticmethod + def _get_transcoded_request(http_options, request): + request_kwargs = json_format.MessageToDict(request) + transcoded_request = path_template.transcode(http_options, **request_kwargs) + return transcoded_request + + @staticmethod + def _get_query_params_json(transcoded_request): + query_params = json.loads(json.dumps(transcoded_request["query_params"])) + return query_params + + +__all__ = ("_BaseTextServiceRestTransport",) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/__init__.py new file mode 100644 index 000000000000..bb351f14ec55 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/__init__.py @@ -0,0 +1,369 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .cache_service import ( + CreateCachedContentRequest, + DeleteCachedContentRequest, + GetCachedContentRequest, + ListCachedContentsRequest, + ListCachedContentsResponse, + UpdateCachedContentRequest, +) +from .cached_content import CachedContent +from .citation import CitationMetadata, CitationSource +from .content import ( + Blob, + CodeExecution, + CodeExecutionResult, + Content, + DynamicRetrievalConfig, + ExecutableCode, + FileData, + FunctionCall, + FunctionCallingConfig, + FunctionDeclaration, + FunctionResponse, + GoogleSearchRetrieval, + GroundingPassage, + GroundingPassages, + Part, + Schema, + Tool, + ToolConfig, + Type, +) +from .discuss_service import ( + CountMessageTokensRequest, + CountMessageTokensResponse, + Example, + GenerateMessageRequest, + GenerateMessageResponse, + Message, + MessagePrompt, +) +from .file import File, VideoMetadata +from .file_service import ( + CreateFileRequest, + CreateFileResponse, + DeleteFileRequest, + GetFileRequest, + ListFilesRequest, + ListFilesResponse, +) +from .generative_service import ( + AttributionSourceId, + BatchEmbedContentsRequest, + BatchEmbedContentsResponse, + BidiGenerateContentClientContent, + BidiGenerateContentClientMessage, + BidiGenerateContentRealtimeInput, + BidiGenerateContentServerContent, + BidiGenerateContentServerMessage, + BidiGenerateContentSetup, + BidiGenerateContentSetupComplete, + BidiGenerateContentToolCall, + BidiGenerateContentToolCallCancellation, + BidiGenerateContentToolResponse, + Candidate, + ContentEmbedding, + CountTokensRequest, + CountTokensResponse, + EmbedContentRequest, + EmbedContentResponse, + GenerateAnswerRequest, + GenerateAnswerResponse, + GenerateContentRequest, + GenerateContentResponse, + GenerationConfig, + GroundingAttribution, + GroundingChunk, + GroundingMetadata, + GroundingSupport, + LogprobsResult, + PrebuiltVoiceConfig, + RetrievalMetadata, + SearchEntryPoint, + Segment, + SemanticRetrieverConfig, + SpeechConfig, + TaskType, + VoiceConfig, +) +from .model import Model +from .model_service import ( + CreateTunedModelMetadata, + CreateTunedModelRequest, + DeleteTunedModelRequest, + GetModelRequest, + GetTunedModelRequest, + ListModelsRequest, + ListModelsResponse, + ListTunedModelsRequest, + ListTunedModelsResponse, + UpdateTunedModelRequest, +) +from .permission import Permission +from .permission_service import ( + CreatePermissionRequest, + DeletePermissionRequest, + GetPermissionRequest, + ListPermissionsRequest, + ListPermissionsResponse, + TransferOwnershipRequest, + TransferOwnershipResponse, + UpdatePermissionRequest, +) +from .prediction_service import PredictRequest, PredictResponse +from .retriever import ( + Chunk, + ChunkData, + Condition, + Corpus, + CustomMetadata, + Document, + MetadataFilter, + StringList, +) +from .retriever_service import ( + BatchCreateChunksRequest, + BatchCreateChunksResponse, + BatchDeleteChunksRequest, + BatchUpdateChunksRequest, + BatchUpdateChunksResponse, + CreateChunkRequest, + CreateCorpusRequest, + CreateDocumentRequest, + DeleteChunkRequest, + DeleteCorpusRequest, + DeleteDocumentRequest, + GetChunkRequest, + GetCorpusRequest, + GetDocumentRequest, + ListChunksRequest, + ListChunksResponse, + ListCorporaRequest, + ListCorporaResponse, + ListDocumentsRequest, + ListDocumentsResponse, + QueryCorpusRequest, + QueryCorpusResponse, + QueryDocumentRequest, + QueryDocumentResponse, + RelevantChunk, + UpdateChunkRequest, + UpdateCorpusRequest, + UpdateDocumentRequest, +) +from .safety import ( + ContentFilter, + HarmCategory, + SafetyFeedback, + SafetyRating, + SafetySetting, +) +from .text_service import ( + BatchEmbedTextRequest, + BatchEmbedTextResponse, + CountTextTokensRequest, + CountTextTokensResponse, + Embedding, + EmbedTextRequest, + EmbedTextResponse, + GenerateTextRequest, + GenerateTextResponse, + TextCompletion, + TextPrompt, +) +from .tuned_model import ( + Dataset, + Hyperparameters, + TunedModel, + TunedModelSource, + TuningContent, + TuningExample, + TuningExamples, + TuningMultiturnExample, + TuningPart, + TuningSnapshot, + TuningTask, +) + +__all__ = ( + "CreateCachedContentRequest", + "DeleteCachedContentRequest", + "GetCachedContentRequest", + "ListCachedContentsRequest", + "ListCachedContentsResponse", + "UpdateCachedContentRequest", + "CachedContent", + "CitationMetadata", + "CitationSource", + "Blob", + "CodeExecution", + "CodeExecutionResult", + "Content", + "DynamicRetrievalConfig", + "ExecutableCode", + "FileData", + "FunctionCall", + "FunctionCallingConfig", + "FunctionDeclaration", + "FunctionResponse", + "GoogleSearchRetrieval", + "GroundingPassage", + "GroundingPassages", + "Part", + "Schema", + "Tool", + "ToolConfig", + "Type", + "CountMessageTokensRequest", + "CountMessageTokensResponse", + "Example", + "GenerateMessageRequest", + "GenerateMessageResponse", + "Message", + "MessagePrompt", + "File", + "VideoMetadata", + "CreateFileRequest", + "CreateFileResponse", + "DeleteFileRequest", + "GetFileRequest", + "ListFilesRequest", + "ListFilesResponse", + "AttributionSourceId", + "BatchEmbedContentsRequest", + "BatchEmbedContentsResponse", + "BidiGenerateContentClientContent", + "BidiGenerateContentClientMessage", + "BidiGenerateContentRealtimeInput", + "BidiGenerateContentServerContent", + "BidiGenerateContentServerMessage", + "BidiGenerateContentSetup", + "BidiGenerateContentSetupComplete", + "BidiGenerateContentToolCall", + "BidiGenerateContentToolCallCancellation", + "BidiGenerateContentToolResponse", + "Candidate", + "ContentEmbedding", + "CountTokensRequest", + "CountTokensResponse", + "EmbedContentRequest", + "EmbedContentResponse", + "GenerateAnswerRequest", + "GenerateAnswerResponse", + "GenerateContentRequest", + "GenerateContentResponse", + "GenerationConfig", + "GroundingAttribution", + "GroundingChunk", + "GroundingMetadata", + "GroundingSupport", + "LogprobsResult", + "PrebuiltVoiceConfig", + "RetrievalMetadata", + "SearchEntryPoint", + "Segment", + "SemanticRetrieverConfig", + "SpeechConfig", + "VoiceConfig", + "TaskType", + "Model", + "CreateTunedModelMetadata", + "CreateTunedModelRequest", + "DeleteTunedModelRequest", + "GetModelRequest", + "GetTunedModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "ListTunedModelsRequest", + "ListTunedModelsResponse", + "UpdateTunedModelRequest", + "Permission", + "CreatePermissionRequest", + "DeletePermissionRequest", + "GetPermissionRequest", + "ListPermissionsRequest", + "ListPermissionsResponse", + "TransferOwnershipRequest", + "TransferOwnershipResponse", + "UpdatePermissionRequest", + "PredictRequest", + "PredictResponse", + "Chunk", + "ChunkData", + "Condition", + "Corpus", + "CustomMetadata", + "Document", + "MetadataFilter", + "StringList", + "BatchCreateChunksRequest", + "BatchCreateChunksResponse", + "BatchDeleteChunksRequest", + "BatchUpdateChunksRequest", + "BatchUpdateChunksResponse", + "CreateChunkRequest", + "CreateCorpusRequest", + "CreateDocumentRequest", + "DeleteChunkRequest", + "DeleteCorpusRequest", + "DeleteDocumentRequest", + "GetChunkRequest", + "GetCorpusRequest", + "GetDocumentRequest", + "ListChunksRequest", + "ListChunksResponse", + "ListCorporaRequest", + "ListCorporaResponse", + "ListDocumentsRequest", + "ListDocumentsResponse", + "QueryCorpusRequest", + "QueryCorpusResponse", + "QueryDocumentRequest", + "QueryDocumentResponse", + "RelevantChunk", + "UpdateChunkRequest", + "UpdateCorpusRequest", + "UpdateDocumentRequest", + "ContentFilter", + "SafetyFeedback", + "SafetyRating", + "SafetySetting", + "HarmCategory", + "BatchEmbedTextRequest", + "BatchEmbedTextResponse", + "CountTextTokensRequest", + "CountTextTokensResponse", + "Embedding", + "EmbedTextRequest", + "EmbedTextResponse", + "GenerateTextRequest", + "GenerateTextResponse", + "TextCompletion", + "TextPrompt", + "Dataset", + "Hyperparameters", + "TunedModel", + "TunedModelSource", + "TuningContent", + "TuningExample", + "TuningExamples", + "TuningMultiturnExample", + "TuningPart", + "TuningSnapshot", + "TuningTask", +) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cache_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cache_service.py new file mode 100644 index 000000000000..74a6145eb924 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cache_service.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import field_mask_pb2 # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "ListCachedContentsRequest", + "ListCachedContentsResponse", + "CreateCachedContentRequest", + "GetCachedContentRequest", + "UpdateCachedContentRequest", + "DeleteCachedContentRequest", + }, +) + + +class ListCachedContentsRequest(proto.Message): + r"""Request to list CachedContents. + + Attributes: + page_size (int): + Optional. The maximum number of cached + contents to return. The service may return fewer + than this value. If unspecified, some default + (under maximum) number of items will be + returned. The maximum value is 1000; values + above 1000 will be coerced to 1000. + page_token (str): + Optional. A page token, received from a previous + ``ListCachedContents`` call. Provide this to retrieve the + subsequent page. + + When paginating, all other parameters provided to + ``ListCachedContents`` must match the call that provided the + page token. + """ + + page_size: int = proto.Field( + proto.INT32, + number=1, + ) + page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class ListCachedContentsResponse(proto.Message): + r"""Response with CachedContents list. + + Attributes: + cached_contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.CachedContent]): + List of cached contents. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. If this field is omitted, there are no subsequent + pages. + """ + + @property + def raw_page(self): + return self + + cached_contents: MutableSequence[ + gag_cached_content.CachedContent + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_cached_content.CachedContent, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class CreateCachedContentRequest(proto.Message): + r"""Request to create CachedContent. + + Attributes: + cached_content (google.ai.generativelanguage_v1alpha.types.CachedContent): + Required. The cached content to create. + """ + + cached_content: gag_cached_content.CachedContent = proto.Field( + proto.MESSAGE, + number=1, + message=gag_cached_content.CachedContent, + ) + + +class GetCachedContentRequest(proto.Message): + r"""Request to read CachedContent. + + Attributes: + name (str): + Required. The resource name referring to the content cache + entry. Format: ``cachedContents/{id}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class UpdateCachedContentRequest(proto.Message): + r"""Request to update CachedContent. + + Attributes: + cached_content (google.ai.generativelanguage_v1alpha.types.CachedContent): + Required. The content cache entry to update + update_mask (google.protobuf.field_mask_pb2.FieldMask): + The list of fields to update. + """ + + cached_content: gag_cached_content.CachedContent = proto.Field( + proto.MESSAGE, + number=1, + message=gag_cached_content.CachedContent, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class DeleteCachedContentRequest(proto.Message): + r"""Request to delete CachedContent. + + Attributes: + name (str): + Required. The resource name referring to the content cache + entry Format: ``cachedContents/{id}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cached_content.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cached_content.py new file mode 100644 index 000000000000..3eeed986fcc9 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/cached_content.py @@ -0,0 +1,182 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import content + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "CachedContent", + }, +) + + +class CachedContent(proto.Message): + r"""Content that has been preprocessed and can be used in + subsequent request to GenerativeService. + + Cached content can be only used with model it was created for. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + expire_time (google.protobuf.timestamp_pb2.Timestamp): + Timestamp in UTC of when this resource is considered + expired. This is *always* provided on output, regardless of + what was sent on input. + + This field is a member of `oneof`_ ``expiration``. + ttl (google.protobuf.duration_pb2.Duration): + Input only. New TTL for this resource, input + only. + + This field is a member of `oneof`_ ``expiration``. + name (str): + Optional. Identifier. The resource name referring to the + cached content. Format: ``cachedContents/{id}`` + + This field is a member of `oneof`_ ``_name``. + display_name (str): + Optional. Immutable. The user-generated + meaningful display name of the cached content. + Maximum 128 Unicode characters. + + This field is a member of `oneof`_ ``_display_name``. + model (str): + Required. Immutable. The name of the ``Model`` to use for + cached content Format: ``models/{model}`` + + This field is a member of `oneof`_ ``_model``. + system_instruction (google.ai.generativelanguage_v1alpha.types.Content): + Optional. Input only. Immutable. Developer + set system instruction. Currently text only. + + This field is a member of `oneof`_ ``_system_instruction``. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Optional. Input only. Immutable. The content + to cache. + tools (MutableSequence[google.ai.generativelanguage_v1alpha.types.Tool]): + Optional. Input only. Immutable. A list of ``Tools`` the + model may use to generate the next response + tool_config (google.ai.generativelanguage_v1alpha.types.ToolConfig): + Optional. Input only. Immutable. Tool config. + This config is shared for all tools. + + This field is a member of `oneof`_ ``_tool_config``. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. Creation time of the cache + entry. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. When the cache entry was last + updated in UTC time. + usage_metadata (google.ai.generativelanguage_v1alpha.types.CachedContent.UsageMetadata): + Output only. Metadata on the usage of the + cached content. + """ + + class UsageMetadata(proto.Message): + r"""Metadata on the usage of the cached content. + + Attributes: + total_token_count (int): + Total number of tokens that the cached + content consumes. + """ + + total_token_count: int = proto.Field( + proto.INT32, + number=1, + ) + + expire_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=9, + oneof="expiration", + message=timestamp_pb2.Timestamp, + ) + ttl: duration_pb2.Duration = proto.Field( + proto.MESSAGE, + number=10, + oneof="expiration", + message=duration_pb2.Duration, + ) + name: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + display_name: str = proto.Field( + proto.STRING, + number=11, + optional=True, + ) + model: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + system_instruction: content.Content = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=content.Content, + ) + contents: MutableSequence[content.Content] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=content.Content, + ) + tools: MutableSequence[content.Tool] = proto.RepeatedField( + proto.MESSAGE, + number=5, + message=content.Tool, + ) + tool_config: content.ToolConfig = proto.Field( + proto.MESSAGE, + number=6, + optional=True, + message=content.ToolConfig, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=8, + message=timestamp_pb2.Timestamp, + ) + usage_metadata: UsageMetadata = proto.Field( + proto.MESSAGE, + number=12, + message=UsageMetadata, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/citation.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/citation.py new file mode 100644 index 000000000000..837bd3f4b2ea --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/citation.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "CitationMetadata", + "CitationSource", + }, +) + + +class CitationMetadata(proto.Message): + r"""A collection of source attributions for a piece of content. + + Attributes: + citation_sources (MutableSequence[google.ai.generativelanguage_v1alpha.types.CitationSource]): + Citations to sources for a specific response. + """ + + citation_sources: MutableSequence["CitationSource"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="CitationSource", + ) + + +class CitationSource(proto.Message): + r"""A citation to a source for a portion of a specific response. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + start_index (int): + Optional. Start of segment of the response + that is attributed to this source. + + Index indicates the start of the segment, + measured in bytes. + + This field is a member of `oneof`_ ``_start_index``. + end_index (int): + Optional. End of the attributed segment, + exclusive. + + This field is a member of `oneof`_ ``_end_index``. + uri (str): + Optional. URI that is attributed as a source + for a portion of the text. + + This field is a member of `oneof`_ ``_uri``. + license_ (str): + Optional. License for the GitHub project that + is attributed as a source for segment. + + License info is required for code citations. + + This field is a member of `oneof`_ ``_license``. + """ + + start_index: int = proto.Field( + proto.INT32, + number=1, + optional=True, + ) + end_index: int = proto.Field( + proto.INT32, + number=2, + optional=True, + ) + uri: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + license_: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/content.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/content.py new file mode 100644 index 000000000000..d8ec135a889f --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/content.py @@ -0,0 +1,819 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import struct_pb2 # type: ignore +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "Type", + "Content", + "Part", + "Blob", + "FileData", + "ExecutableCode", + "CodeExecutionResult", + "Tool", + "GoogleSearchRetrieval", + "DynamicRetrievalConfig", + "CodeExecution", + "ToolConfig", + "FunctionCallingConfig", + "FunctionDeclaration", + "FunctionCall", + "FunctionResponse", + "Schema", + "GroundingPassage", + "GroundingPassages", + }, +) + + +class Type(proto.Enum): + r"""Type contains the list of OpenAPI data types as defined by + https://spec.openapis.org/oas/v3.0.3#data-types + + Values: + TYPE_UNSPECIFIED (0): + Not specified, should not be used. + STRING (1): + String type. + NUMBER (2): + Number type. + INTEGER (3): + Integer type. + BOOLEAN (4): + Boolean type. + ARRAY (5): + Array type. + OBJECT (6): + Object type. + """ + TYPE_UNSPECIFIED = 0 + STRING = 1 + NUMBER = 2 + INTEGER = 3 + BOOLEAN = 4 + ARRAY = 5 + OBJECT = 6 + + +class Content(proto.Message): + r"""The base structured datatype containing multi-part content of a + message. + + A ``Content`` includes a ``role`` field designating the producer of + the ``Content`` and a ``parts`` field containing multi-part data + that contains the content of the message turn. + + Attributes: + parts (MutableSequence[google.ai.generativelanguage_v1alpha.types.Part]): + Ordered ``Parts`` that constitute a single message. Parts + may have different MIME types. + role (str): + Optional. The producer of the content. Must + be either 'user' or 'model'. + Useful to set for multi-turn conversations, + otherwise can be left blank or unset. + """ + + parts: MutableSequence["Part"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="Part", + ) + role: str = proto.Field( + proto.STRING, + number=2, + ) + + +class Part(proto.Message): + r"""A datatype containing media that is part of a multi-part ``Content`` + message. + + A ``Part`` consists of data which has an associated datatype. A + ``Part`` can only contain one of the accepted types in + ``Part.data``. + + A ``Part`` must have a fixed IANA MIME type identifying the type and + subtype of the media if the ``inline_data`` field is filled with raw + bytes. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + text (str): + Inline text. + + This field is a member of `oneof`_ ``data``. + inline_data (google.ai.generativelanguage_v1alpha.types.Blob): + Inline media bytes. + + This field is a member of `oneof`_ ``data``. + function_call (google.ai.generativelanguage_v1alpha.types.FunctionCall): + A predicted ``FunctionCall`` returned from the model that + contains a string representing the + ``FunctionDeclaration.name`` with the arguments and their + values. + + This field is a member of `oneof`_ ``data``. + function_response (google.ai.generativelanguage_v1alpha.types.FunctionResponse): + The result output of a ``FunctionCall`` that contains a + string representing the ``FunctionDeclaration.name`` and a + structured JSON object containing any output from the + function is used as context to the model. + + This field is a member of `oneof`_ ``data``. + file_data (google.ai.generativelanguage_v1alpha.types.FileData): + URI based data. + + This field is a member of `oneof`_ ``data``. + executable_code (google.ai.generativelanguage_v1alpha.types.ExecutableCode): + Code generated by the model that is meant to + be executed. + + This field is a member of `oneof`_ ``data``. + code_execution_result (google.ai.generativelanguage_v1alpha.types.CodeExecutionResult): + Result of executing the ``ExecutableCode``. + + This field is a member of `oneof`_ ``data``. + """ + + text: str = proto.Field( + proto.STRING, + number=2, + oneof="data", + ) + inline_data: "Blob" = proto.Field( + proto.MESSAGE, + number=3, + oneof="data", + message="Blob", + ) + function_call: "FunctionCall" = proto.Field( + proto.MESSAGE, + number=4, + oneof="data", + message="FunctionCall", + ) + function_response: "FunctionResponse" = proto.Field( + proto.MESSAGE, + number=5, + oneof="data", + message="FunctionResponse", + ) + file_data: "FileData" = proto.Field( + proto.MESSAGE, + number=6, + oneof="data", + message="FileData", + ) + executable_code: "ExecutableCode" = proto.Field( + proto.MESSAGE, + number=9, + oneof="data", + message="ExecutableCode", + ) + code_execution_result: "CodeExecutionResult" = proto.Field( + proto.MESSAGE, + number=10, + oneof="data", + message="CodeExecutionResult", + ) + + +class Blob(proto.Message): + r"""Raw media bytes. + + Text should not be sent as raw bytes, use the 'text' field. + + Attributes: + mime_type (str): + The IANA standard MIME type of the source data. Examples: + + - image/png + - image/jpeg If an unsupported MIME type is provided, an + error will be returned. For a complete list of supported + types, see `Supported file + formats `__. + data (bytes): + Raw bytes for media formats. + """ + + mime_type: str = proto.Field( + proto.STRING, + number=1, + ) + data: bytes = proto.Field( + proto.BYTES, + number=2, + ) + + +class FileData(proto.Message): + r"""URI based data. + + Attributes: + mime_type (str): + Optional. The IANA standard MIME type of the + source data. + file_uri (str): + Required. URI. + """ + + mime_type: str = proto.Field( + proto.STRING, + number=1, + ) + file_uri: str = proto.Field( + proto.STRING, + number=2, + ) + + +class ExecutableCode(proto.Message): + r"""Code generated by the model that is meant to be executed, and the + result returned to the model. + + Only generated when using the ``CodeExecution`` tool, in which the + code will be automatically executed, and a corresponding + ``CodeExecutionResult`` will also be generated. + + Attributes: + language (google.ai.generativelanguage_v1alpha.types.ExecutableCode.Language): + Required. Programming language of the ``code``. + code (str): + Required. The code to be executed. + """ + + class Language(proto.Enum): + r"""Supported programming languages for the generated code. + + Values: + LANGUAGE_UNSPECIFIED (0): + Unspecified language. This value should not + be used. + PYTHON (1): + Python >= 3.10, with numpy and simpy + available. + """ + LANGUAGE_UNSPECIFIED = 0 + PYTHON = 1 + + language: Language = proto.Field( + proto.ENUM, + number=1, + enum=Language, + ) + code: str = proto.Field( + proto.STRING, + number=2, + ) + + +class CodeExecutionResult(proto.Message): + r"""Result of executing the ``ExecutableCode``. + + Only generated when using the ``CodeExecution``, and always follows + a ``part`` containing the ``ExecutableCode``. + + Attributes: + outcome (google.ai.generativelanguage_v1alpha.types.CodeExecutionResult.Outcome): + Required. Outcome of the code execution. + output (str): + Optional. Contains stdout when code execution + is successful, stderr or other description + otherwise. + """ + + class Outcome(proto.Enum): + r"""Enumeration of possible outcomes of the code execution. + + Values: + OUTCOME_UNSPECIFIED (0): + Unspecified status. This value should not be + used. + OUTCOME_OK (1): + Code execution completed successfully. + OUTCOME_FAILED (2): + Code execution finished but with a failure. ``stderr`` + should contain the reason. + OUTCOME_DEADLINE_EXCEEDED (3): + Code execution ran for too long, and was + cancelled. There may or may not be a partial + output present. + """ + OUTCOME_UNSPECIFIED = 0 + OUTCOME_OK = 1 + OUTCOME_FAILED = 2 + OUTCOME_DEADLINE_EXCEEDED = 3 + + outcome: Outcome = proto.Field( + proto.ENUM, + number=1, + enum=Outcome, + ) + output: str = proto.Field( + proto.STRING, + number=2, + ) + + +class Tool(proto.Message): + r"""Tool details that the model may use to generate response. + + A ``Tool`` is a piece of code that enables the system to interact + with external systems to perform an action, or set of actions, + outside of knowledge and scope of the model. + + Attributes: + function_declarations (MutableSequence[google.ai.generativelanguage_v1alpha.types.FunctionDeclaration]): + Optional. A list of ``FunctionDeclarations`` available to + the model that can be used for function calling. + + The model or system does not execute the function. Instead + the defined function may be returned as a + [FunctionCall][google.ai.generativelanguage.v1alpha.Part.function_call] + with arguments to the client side for execution. The model + may decide to call a subset of these functions by populating + [FunctionCall][google.ai.generativelanguage.v1alpha.Part.function_call] + in the response. The next conversation turn may contain a + [FunctionResponse][google.ai.generativelanguage.v1alpha.Part.function_response] + with the + [Content.role][google.ai.generativelanguage.v1alpha.Content.role] + "function" generation context for the next model turn. + google_search_retrieval (google.ai.generativelanguage_v1alpha.types.GoogleSearchRetrieval): + Optional. Retrieval tool that is powered by + Google search. + code_execution (google.ai.generativelanguage_v1alpha.types.CodeExecution): + Optional. Enables the model to execute code + as part of generation. + google_search (google.ai.generativelanguage_v1alpha.types.Tool.GoogleSearch): + Optional. GoogleSearch tool type. + Tool to support Google Search in Model. Powered + by Google. + """ + + class GoogleSearch(proto.Message): + r"""GoogleSearch tool type. + Tool to support Google Search in Model. Powered by Google. + + """ + + function_declarations: MutableSequence["FunctionDeclaration"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="FunctionDeclaration", + ) + google_search_retrieval: "GoogleSearchRetrieval" = proto.Field( + proto.MESSAGE, + number=2, + message="GoogleSearchRetrieval", + ) + code_execution: "CodeExecution" = proto.Field( + proto.MESSAGE, + number=3, + message="CodeExecution", + ) + google_search: GoogleSearch = proto.Field( + proto.MESSAGE, + number=4, + message=GoogleSearch, + ) + + +class GoogleSearchRetrieval(proto.Message): + r"""Tool to retrieve public web data for grounding, powered by + Google. + + Attributes: + dynamic_retrieval_config (google.ai.generativelanguage_v1alpha.types.DynamicRetrievalConfig): + Specifies the dynamic retrieval configuration + for the given source. + """ + + dynamic_retrieval_config: "DynamicRetrievalConfig" = proto.Field( + proto.MESSAGE, + number=1, + message="DynamicRetrievalConfig", + ) + + +class DynamicRetrievalConfig(proto.Message): + r"""Describes the options to customize dynamic retrieval. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + mode (google.ai.generativelanguage_v1alpha.types.DynamicRetrievalConfig.Mode): + The mode of the predictor to be used in + dynamic retrieval. + dynamic_threshold (float): + The threshold to be used in dynamic + retrieval. If not set, a system default value is + used. + + This field is a member of `oneof`_ ``_dynamic_threshold``. + """ + + class Mode(proto.Enum): + r"""The mode of the predictor to be used in dynamic retrieval. + + Values: + MODE_UNSPECIFIED (0): + Always trigger retrieval. + MODE_DYNAMIC (1): + Run retrieval only when system decides it is + necessary. + """ + MODE_UNSPECIFIED = 0 + MODE_DYNAMIC = 1 + + mode: Mode = proto.Field( + proto.ENUM, + number=1, + enum=Mode, + ) + dynamic_threshold: float = proto.Field( + proto.FLOAT, + number=2, + optional=True, + ) + + +class CodeExecution(proto.Message): + r"""Tool that executes code generated by the model, and automatically + returns the result to the model. + + See also ``ExecutableCode`` and ``CodeExecutionResult`` which are + only generated when using this tool. + + """ + + +class ToolConfig(proto.Message): + r"""The Tool configuration containing parameters for specifying ``Tool`` + use in the request. + + Attributes: + function_calling_config (google.ai.generativelanguage_v1alpha.types.FunctionCallingConfig): + Optional. Function calling config. + """ + + function_calling_config: "FunctionCallingConfig" = proto.Field( + proto.MESSAGE, + number=1, + message="FunctionCallingConfig", + ) + + +class FunctionCallingConfig(proto.Message): + r"""Configuration for specifying function calling behavior. + + Attributes: + mode (google.ai.generativelanguage_v1alpha.types.FunctionCallingConfig.Mode): + Optional. Specifies the mode in which + function calling should execute. If unspecified, + the default value will be set to AUTO. + allowed_function_names (MutableSequence[str]): + Optional. A set of function names that, when provided, + limits the functions the model will call. + + This should only be set when the Mode is ANY. Function names + should match [FunctionDeclaration.name]. With mode set to + ANY, model will predict a function call from the set of + function names provided. + """ + + class Mode(proto.Enum): + r"""Defines the execution behavior for function calling by + defining the execution mode. + + Values: + MODE_UNSPECIFIED (0): + Unspecified function calling mode. This value + should not be used. + AUTO (1): + Default model behavior, model decides to + predict either a function call or a natural + language response. + ANY (2): + Model is constrained to always predicting a function call + only. If "allowed_function_names" are set, the predicted + function call will be limited to any one of + "allowed_function_names", else the predicted function call + will be any one of the provided "function_declarations". + NONE (3): + Model will not predict any function call. + Model behavior is same as when not passing any + function declarations. + """ + MODE_UNSPECIFIED = 0 + AUTO = 1 + ANY = 2 + NONE = 3 + + mode: Mode = proto.Field( + proto.ENUM, + number=1, + enum=Mode, + ) + allowed_function_names: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=2, + ) + + +class FunctionDeclaration(proto.Message): + r"""Structured representation of a function declaration as defined by + the `OpenAPI 3.03 + specification `__. Included in + this declaration are the function name and parameters. This + FunctionDeclaration is a representation of a block of code that can + be used as a ``Tool`` by the model and executed by the client. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + name (str): + Required. The name of the function. + Must be a-z, A-Z, 0-9, or contain underscores + and dashes, with a maximum length of 63. + description (str): + Required. A brief description of the + function. + parameters (google.ai.generativelanguage_v1alpha.types.Schema): + Optional. Describes the parameters to this + function. Reflects the Open API 3.03 Parameter + Object string Key: the name of the parameter. + Parameter names are case sensitive. Schema + Value: the Schema defining the type used for the + parameter. + + This field is a member of `oneof`_ ``_parameters``. + response (google.ai.generativelanguage_v1alpha.types.Schema): + Optional. Describes the output from this + function in JSON Schema format. Reflects the + Open API 3.03 Response Object. The Schema + defines the type used for the response value of + the function. + + This field is a member of `oneof`_ ``_response``. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + description: str = proto.Field( + proto.STRING, + number=2, + ) + parameters: "Schema" = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message="Schema", + ) + response: "Schema" = proto.Field( + proto.MESSAGE, + number=4, + optional=True, + message="Schema", + ) + + +class FunctionCall(proto.Message): + r"""A predicted ``FunctionCall`` returned from the model that contains a + string representing the ``FunctionDeclaration.name`` with the + arguments and their values. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + id (str): + Optional. The unique id of the function call. If populated, + the client to execute the ``function_call`` and return the + response with the matching ``id``. + name (str): + Required. The name of the function to call. + Must be a-z, A-Z, 0-9, or contain underscores + and dashes, with a maximum length of 63. + args (google.protobuf.struct_pb2.Struct): + Optional. The function parameters and values + in JSON object format. + + This field is a member of `oneof`_ ``_args``. + """ + + id: str = proto.Field( + proto.STRING, + number=3, + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + args: struct_pb2.Struct = proto.Field( + proto.MESSAGE, + number=2, + optional=True, + message=struct_pb2.Struct, + ) + + +class FunctionResponse(proto.Message): + r"""The result output from a ``FunctionCall`` that contains a string + representing the ``FunctionDeclaration.name`` and a structured JSON + object containing any output from the function is used as context to + the model. This should contain the result of a\ ``FunctionCall`` + made based on model prediction. + + Attributes: + id (str): + Optional. The id of the function call this response is for. + Populated by the client to match the corresponding function + call ``id``. + name (str): + Required. The name of the function to call. + Must be a-z, A-Z, 0-9, or contain underscores + and dashes, with a maximum length of 63. + response (google.protobuf.struct_pb2.Struct): + Required. The function response in JSON + object format. + """ + + id: str = proto.Field( + proto.STRING, + number=3, + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + response: struct_pb2.Struct = proto.Field( + proto.MESSAGE, + number=2, + message=struct_pb2.Struct, + ) + + +class Schema(proto.Message): + r"""The ``Schema`` object allows the definition of input and output data + types. These types can be objects, but also primitives and arrays. + Represents a select subset of an `OpenAPI 3.0 schema + object `__. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + type_ (google.ai.generativelanguage_v1alpha.types.Type): + Required. Data type. + format_ (str): + Optional. The format of the data. This is + used only for primitive datatypes. Supported + formats: + + for NUMBER type: float, double + for INTEGER type: int32, int64 + for STRING type: enum + description (str): + Optional. A brief description of the + parameter. This could contain examples of use. + Parameter description may be formatted as + Markdown. + nullable (bool): + Optional. Indicates if the value may be null. + enum (MutableSequence[str]): + Optional. Possible values of the element of Type.STRING with + enum format. For example we can define an Enum Direction as + : {type:STRING, format:enum, enum:["EAST", NORTH", "SOUTH", + "WEST"]} + items (google.ai.generativelanguage_v1alpha.types.Schema): + Optional. Schema of the elements of + Type.ARRAY. + + This field is a member of `oneof`_ ``_items``. + max_items (int): + Optional. Maximum number of the elements for + Type.ARRAY. + min_items (int): + Optional. Minimum number of the elements for + Type.ARRAY. + properties (MutableMapping[str, google.ai.generativelanguage_v1alpha.types.Schema]): + Optional. Properties of Type.OBJECT. + required (MutableSequence[str]): + Optional. Required properties of Type.OBJECT. + """ + + type_: "Type" = proto.Field( + proto.ENUM, + number=1, + enum="Type", + ) + format_: str = proto.Field( + proto.STRING, + number=2, + ) + description: str = proto.Field( + proto.STRING, + number=3, + ) + nullable: bool = proto.Field( + proto.BOOL, + number=4, + ) + enum: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=5, + ) + items: "Schema" = proto.Field( + proto.MESSAGE, + number=6, + optional=True, + message="Schema", + ) + max_items: int = proto.Field( + proto.INT64, + number=21, + ) + min_items: int = proto.Field( + proto.INT64, + number=22, + ) + properties: MutableMapping[str, "Schema"] = proto.MapField( + proto.STRING, + proto.MESSAGE, + number=7, + message="Schema", + ) + required: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=8, + ) + + +class GroundingPassage(proto.Message): + r"""Passage included inline with a grounding configuration. + + Attributes: + id (str): + Identifier for the passage for attributing + this passage in grounded answers. + content (google.ai.generativelanguage_v1alpha.types.Content): + Content of the passage. + """ + + id: str = proto.Field( + proto.STRING, + number=1, + ) + content: "Content" = proto.Field( + proto.MESSAGE, + number=2, + message="Content", + ) + + +class GroundingPassages(proto.Message): + r"""A repeated list of passages. + + Attributes: + passages (MutableSequence[google.ai.generativelanguage_v1alpha.types.GroundingPassage]): + List of passages. + """ + + passages: MutableSequence["GroundingPassage"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="GroundingPassage", + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/discuss_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/discuss_service.py new file mode 100644 index 000000000000..2ce5d18dc6b1 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/discuss_service.py @@ -0,0 +1,356 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import citation, safety + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "GenerateMessageRequest", + "GenerateMessageResponse", + "Message", + "MessagePrompt", + "Example", + "CountMessageTokensRequest", + "CountMessageTokensResponse", + }, +) + + +class GenerateMessageRequest(proto.Message): + r"""Request to generate a message response from the model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The name of the model to use. + + Format: ``name=models/{model}``. + prompt (google.ai.generativelanguage_v1alpha.types.MessagePrompt): + Required. The structured textual input given + to the model as a prompt. + Given a + prompt, the model will return what it predicts + is the next message in the discussion. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This field is a member of `oneof`_ ``_temperature``. + candidate_count (int): + Optional. The number of generated response messages to + return. + + This value must be between ``[1, 8]``, inclusive. If unset, + this will default to ``1``. + + This field is a member of `oneof`_ ``_candidate_count``. + top_p (float): + Optional. The maximum cumulative probability of tokens to + consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Nucleus sampling considers the smallest set of tokens whose + probability sum is at least ``top_p``. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. + + This field is a member of `oneof`_ ``_top_k``. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: "MessagePrompt" = proto.Field( + proto.MESSAGE, + number=2, + message="MessagePrompt", + ) + temperature: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + candidate_count: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=5, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=6, + optional=True, + ) + + +class GenerateMessageResponse(proto.Message): + r"""The response from the model. + + This includes candidate messages and + conversation history in the form of chronologically-ordered + messages. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1alpha.types.Message]): + Candidate response messages from the model. + messages (MutableSequence[google.ai.generativelanguage_v1alpha.types.Message]): + The conversation history used by the model. + filters (MutableSequence[google.ai.generativelanguage_v1alpha.types.ContentFilter]): + A set of content filtering metadata for the prompt and + response text. + + This indicates which ``SafetyCategory``\ (s) blocked a + candidate from this response, the lowest ``HarmProbability`` + that triggered a block, and the HarmThreshold setting for + that category. + """ + + candidates: MutableSequence["Message"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="Message", + ) + messages: MutableSequence["Message"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Message", + ) + filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.ContentFilter, + ) + + +class Message(proto.Message): + r"""The base unit of structured text. + + A ``Message`` includes an ``author`` and the ``content`` of the + ``Message``. + + The ``author`` is used to tag messages when they are fed to the + model as text. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + author (str): + Optional. The author of this Message. + + This serves as a key for tagging + the content of this Message when it is fed to + the model as text. + + The author can be any alphanumeric string. + content (str): + Required. The text content of the structured ``Message``. + citation_metadata (google.ai.generativelanguage_v1alpha.types.CitationMetadata): + Output only. Citation information for model-generated + ``content`` in this ``Message``. + + If this ``Message`` was generated as output from the model, + this field may be populated with attribution information for + any text included in the ``content``. This field is used + only on output. + + This field is a member of `oneof`_ ``_citation_metadata``. + """ + + author: str = proto.Field( + proto.STRING, + number=1, + ) + content: str = proto.Field( + proto.STRING, + number=2, + ) + citation_metadata: citation.CitationMetadata = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=citation.CitationMetadata, + ) + + +class MessagePrompt(proto.Message): + r"""All of the structured input text passed to the model as a prompt. + + A ``MessagePrompt`` contains a structured set of fields that provide + context for the conversation, examples of user input/model output + message pairs that prime the model to respond in different ways, and + the conversation history or list of messages representing the + alternating turns of the conversation between the user and the + model. + + Attributes: + context (str): + Optional. Text that should be provided to the model first to + ground the response. + + If not empty, this ``context`` will be given to the model + first before the ``examples`` and ``messages``. When using a + ``context`` be sure to provide it with every request to + maintain continuity. + + This field can be a description of your prompt to the model + to help provide context and guide the responses. Examples: + "Translate the phrase from English to French." or "Given a + statement, classify the sentiment as happy, sad or neutral." + + Anything included in this field will take precedence over + message history if the total input size exceeds the model's + ``input_token_limit`` and the input request is truncated. + examples (MutableSequence[google.ai.generativelanguage_v1alpha.types.Example]): + Optional. Examples of what the model should generate. + + This includes both user input and the response that the + model should emulate. + + These ``examples`` are treated identically to conversation + messages except that they take precedence over the history + in ``messages``: If the total input size exceeds the model's + ``input_token_limit`` the input will be truncated. Items + will be dropped from ``messages`` before ``examples``. + messages (MutableSequence[google.ai.generativelanguage_v1alpha.types.Message]): + Required. A snapshot of the recent conversation history + sorted chronologically. + + Turns alternate between two authors. + + If the total input size exceeds the model's + ``input_token_limit`` the input will be truncated: The + oldest items will be dropped from ``messages``. + """ + + context: str = proto.Field( + proto.STRING, + number=1, + ) + examples: MutableSequence["Example"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Example", + ) + messages: MutableSequence["Message"] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="Message", + ) + + +class Example(proto.Message): + r"""An input/output example used to instruct the Model. + + It demonstrates how the model should respond or format its + response. + + Attributes: + input (google.ai.generativelanguage_v1alpha.types.Message): + Required. An example of an input ``Message`` from the user. + output (google.ai.generativelanguage_v1alpha.types.Message): + Required. An example of what the model should + output given the input. + """ + + input: "Message" = proto.Field( + proto.MESSAGE, + number=1, + message="Message", + ) + output: "Message" = proto.Field( + proto.MESSAGE, + number=2, + message="Message", + ) + + +class CountMessageTokensRequest(proto.Message): + r"""Counts the number of tokens in the ``prompt`` sent to a model. + + Models may tokenize text differently, so each model may return a + different ``token_count``. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + prompt (google.ai.generativelanguage_v1alpha.types.MessagePrompt): + Required. The prompt, whose token count is to + be returned. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: "MessagePrompt" = proto.Field( + proto.MESSAGE, + number=2, + message="MessagePrompt", + ) + + +class CountMessageTokensResponse(proto.Message): + r"""A response from ``CountMessageTokens``. + + It returns the model's ``token_count`` for the ``prompt``. + + Attributes: + token_count (int): + The number of tokens that the ``model`` tokenizes the + ``prompt`` into. + + Always non-negative. + """ + + token_count: int = proto.Field( + proto.INT32, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file.py new file mode 100644 index 000000000000..8a3fdac30930 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "File", + "VideoMetadata", + }, +) + + +class File(proto.Message): + r"""A file uploaded to the API. + Next ID: 15 + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + video_metadata (google.ai.generativelanguage_v1alpha.types.VideoMetadata): + Output only. Metadata for a video. + + This field is a member of `oneof`_ ``metadata``. + name (str): + Immutable. Identifier. The ``File`` resource name. The ID + (name excluding the "files/" prefix) can contain up to 40 + characters that are lowercase alphanumeric or dashes (-). + The ID cannot start or end with a dash. If the name is empty + on create, a unique name will be generated. Example: + ``files/123-456`` + display_name (str): + Optional. The human-readable display name for the ``File``. + The display name must be no more than 512 characters in + length, including spaces. Example: "Welcome Image". + mime_type (str): + Output only. MIME type of the file. + size_bytes (int): + Output only. Size of the file in bytes. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp of when the ``File`` was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp of when the ``File`` was last + updated. + expiration_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp of when the ``File`` will be + deleted. Only set if the ``File`` is scheduled to expire. + sha256_hash (bytes): + Output only. SHA-256 hash of the uploaded + bytes. + uri (str): + Output only. The uri of the ``File``. + state (google.ai.generativelanguage_v1alpha.types.File.State): + Output only. Processing state of the File. + error (google.rpc.status_pb2.Status): + Output only. Error status if File processing + failed. + """ + + class State(proto.Enum): + r"""States for the lifecycle of a File. + + Values: + STATE_UNSPECIFIED (0): + The default value. This value is used if the + state is omitted. + PROCESSING (1): + File is being processed and cannot be used + for inference yet. + ACTIVE (2): + File is processed and available for + inference. + FAILED (10): + File failed processing. + """ + STATE_UNSPECIFIED = 0 + PROCESSING = 1 + ACTIVE = 2 + FAILED = 10 + + video_metadata: "VideoMetadata" = proto.Field( + proto.MESSAGE, + number=12, + oneof="metadata", + message="VideoMetadata", + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=2, + ) + mime_type: str = proto.Field( + proto.STRING, + number=3, + ) + size_bytes: int = proto.Field( + proto.INT64, + number=4, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=5, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=6, + message=timestamp_pb2.Timestamp, + ) + expiration_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=7, + message=timestamp_pb2.Timestamp, + ) + sha256_hash: bytes = proto.Field( + proto.BYTES, + number=8, + ) + uri: str = proto.Field( + proto.STRING, + number=9, + ) + state: State = proto.Field( + proto.ENUM, + number=10, + enum=State, + ) + error: status_pb2.Status = proto.Field( + proto.MESSAGE, + number=11, + message=status_pb2.Status, + ) + + +class VideoMetadata(proto.Message): + r"""Metadata for a video ``File``. + + Attributes: + video_duration (google.protobuf.duration_pb2.Duration): + Duration of the video. + """ + + video_duration: duration_pb2.Duration = proto.Field( + proto.MESSAGE, + number=1, + message=duration_pb2.Duration, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file_service.py new file mode 100644 index 000000000000..6398c075f2d3 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/file_service.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import file as gag_file + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "CreateFileRequest", + "CreateFileResponse", + "ListFilesRequest", + "ListFilesResponse", + "GetFileRequest", + "DeleteFileRequest", + }, +) + + +class CreateFileRequest(proto.Message): + r"""Request for ``CreateFile``. + + Attributes: + file (google.ai.generativelanguage_v1alpha.types.File): + Optional. Metadata for the file to create. + """ + + file: gag_file.File = proto.Field( + proto.MESSAGE, + number=1, + message=gag_file.File, + ) + + +class CreateFileResponse(proto.Message): + r"""Response for ``CreateFile``. + + Attributes: + file (google.ai.generativelanguage_v1alpha.types.File): + Metadata for the created file. + """ + + file: gag_file.File = proto.Field( + proto.MESSAGE, + number=1, + message=gag_file.File, + ) + + +class ListFilesRequest(proto.Message): + r"""Request for ``ListFiles``. + + Attributes: + page_size (int): + Optional. Maximum number of ``File``\ s to return per page. + If unspecified, defaults to 10. Maximum ``page_size`` is + 100. + page_token (str): + Optional. A page token from a previous ``ListFiles`` call. + """ + + page_size: int = proto.Field( + proto.INT32, + number=1, + ) + page_token: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListFilesResponse(proto.Message): + r"""Response for ``ListFiles``. + + Attributes: + files (MutableSequence[google.ai.generativelanguage_v1alpha.types.File]): + The list of ``File``\ s. + next_page_token (str): + A token that can be sent as a ``page_token`` into a + subsequent ``ListFiles`` call. + """ + + @property + def raw_page(self): + return self + + files: MutableSequence[gag_file.File] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_file.File, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class GetFileRequest(proto.Message): + r"""Request for ``GetFile``. + + Attributes: + name (str): + Required. The name of the ``File`` to get. Example: + ``files/abc-123`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class DeleteFileRequest(proto.Message): + r"""Request for ``DeleteFile``. + + Attributes: + name (str): + Required. The name of the ``File`` to delete. Example: + ``files/abc-123`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/generative_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/generative_service.py new file mode 100644 index 000000000000..c1088ab1682c --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/generative_service.py @@ -0,0 +1,2139 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import citation +from google.ai.generativelanguage_v1alpha.types import content as gag_content +from google.ai.generativelanguage_v1alpha.types import retriever, safety + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "TaskType", + "GenerateContentRequest", + "PrebuiltVoiceConfig", + "VoiceConfig", + "SpeechConfig", + "GenerationConfig", + "SemanticRetrieverConfig", + "GenerateContentResponse", + "Candidate", + "LogprobsResult", + "AttributionSourceId", + "GroundingAttribution", + "RetrievalMetadata", + "GroundingMetadata", + "SearchEntryPoint", + "GroundingChunk", + "Segment", + "GroundingSupport", + "GenerateAnswerRequest", + "GenerateAnswerResponse", + "EmbedContentRequest", + "ContentEmbedding", + "EmbedContentResponse", + "BatchEmbedContentsRequest", + "BatchEmbedContentsResponse", + "CountTokensRequest", + "CountTokensResponse", + "BidiGenerateContentSetup", + "BidiGenerateContentClientContent", + "BidiGenerateContentRealtimeInput", + "BidiGenerateContentToolResponse", + "BidiGenerateContentClientMessage", + "BidiGenerateContentSetupComplete", + "BidiGenerateContentServerContent", + "BidiGenerateContentToolCall", + "BidiGenerateContentToolCallCancellation", + "BidiGenerateContentServerMessage", + }, +) + + +class TaskType(proto.Enum): + r"""Type of task for which the embedding will be used. + + Values: + TASK_TYPE_UNSPECIFIED (0): + Unset value, which will default to one of the + other enum values. + RETRIEVAL_QUERY (1): + Specifies the given text is a query in a + search/retrieval setting. + RETRIEVAL_DOCUMENT (2): + Specifies the given text is a document from + the corpus being searched. + SEMANTIC_SIMILARITY (3): + Specifies the given text will be used for + STS. + CLASSIFICATION (4): + Specifies that the given text will be + classified. + CLUSTERING (5): + Specifies that the embeddings will be used + for clustering. + QUESTION_ANSWERING (6): + Specifies that the given text will be used + for question answering. + FACT_VERIFICATION (7): + Specifies that the given text will be used + for fact verification. + """ + TASK_TYPE_UNSPECIFIED = 0 + RETRIEVAL_QUERY = 1 + RETRIEVAL_DOCUMENT = 2 + SEMANTIC_SIMILARITY = 3 + CLASSIFICATION = 4 + CLUSTERING = 5 + QUESTION_ANSWERING = 6 + FACT_VERIFICATION = 7 + + +class GenerateContentRequest(proto.Message): + r"""Request to generate a completion from the model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The name of the ``Model`` to use for generating + the completion. + + Format: ``models/{model}``. + system_instruction (google.ai.generativelanguage_v1alpha.types.Content): + Optional. Developer set `system + instruction(s) `__. + Currently, text only. + + This field is a member of `oneof`_ ``_system_instruction``. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Required. The content of the current conversation with the + model. + + For single-turn queries, this is a single instance. For + multi-turn queries like + `chat `__, + this is a repeated field that contains the conversation + history and the latest request. + tools (MutableSequence[google.ai.generativelanguage_v1alpha.types.Tool]): + Optional. A list of ``Tools`` the ``Model`` may use to + generate the next response. + + A ``Tool`` is a piece of code that enables the system to + interact with external systems to perform an action, or set + of actions, outside of knowledge and scope of the ``Model``. + Supported ``Tool``\ s are ``Function`` and + ``code_execution``. Refer to the `Function + calling `__ + and the `Code + execution `__ + guides to learn more. + tool_config (google.ai.generativelanguage_v1alpha.types.ToolConfig): + Optional. Tool configuration for any ``Tool`` specified in + the request. Refer to the `Function calling + guide `__ + for a usage example. + safety_settings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]): + Optional. A list of unique ``SafetySetting`` instances for + blocking unsafe content. + + This will be enforced on the + ``GenerateContentRequest.contents`` and + ``GenerateContentResponse.candidates``. There should not be + more than one setting for each ``SafetyCategory`` type. The + API will block any contents and responses that fail to meet + the thresholds set by these settings. This list overrides + the default settings for each ``SafetyCategory`` specified + in the safety_settings. If there is no ``SafetySetting`` for + a given ``SafetyCategory`` provided in the list, the API + will use the default safety setting for that category. Harm + categories HARM_CATEGORY_HATE_SPEECH, + HARM_CATEGORY_SEXUALLY_EXPLICIT, + HARM_CATEGORY_DANGEROUS_CONTENT, HARM_CATEGORY_HARASSMENT, + HARM_CATEGORY_CIVIC_INTEGRITY are supported. Refer to the + `guide `__ + for detailed information on available safety settings. Also + refer to the `Safety + guidance `__ + to learn how to incorporate safety considerations in your AI + applications. + generation_config (google.ai.generativelanguage_v1alpha.types.GenerationConfig): + Optional. Configuration options for model + generation and outputs. + + This field is a member of `oneof`_ ``_generation_config``. + cached_content (str): + Optional. The name of the content + `cached `__ + to use as context to serve the prediction. Format: + ``cachedContents/{cachedContent}`` + + This field is a member of `oneof`_ ``_cached_content``. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + system_instruction: gag_content.Content = proto.Field( + proto.MESSAGE, + number=8, + optional=True, + message=gag_content.Content, + ) + contents: MutableSequence[gag_content.Content] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gag_content.Content, + ) + tools: MutableSequence[gag_content.Tool] = proto.RepeatedField( + proto.MESSAGE, + number=5, + message=gag_content.Tool, + ) + tool_config: gag_content.ToolConfig = proto.Field( + proto.MESSAGE, + number=7, + message=gag_content.ToolConfig, + ) + safety_settings: MutableSequence[safety.SafetySetting] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.SafetySetting, + ) + generation_config: "GenerationConfig" = proto.Field( + proto.MESSAGE, + number=4, + optional=True, + message="GenerationConfig", + ) + cached_content: str = proto.Field( + proto.STRING, + number=9, + optional=True, + ) + + +class PrebuiltVoiceConfig(proto.Message): + r"""The configuration for the prebuilt speaker to use. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + voice_name (str): + The name of the preset voice to use. + + This field is a member of `oneof`_ ``_voice_name``. + """ + + voice_name: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class VoiceConfig(proto.Message): + r"""The configuration for the voice to use. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prebuilt_voice_config (google.ai.generativelanguage_v1alpha.types.PrebuiltVoiceConfig): + The configuration for the prebuilt voice to + use. + + This field is a member of `oneof`_ ``voice_config``. + """ + + prebuilt_voice_config: "PrebuiltVoiceConfig" = proto.Field( + proto.MESSAGE, + number=1, + oneof="voice_config", + message="PrebuiltVoiceConfig", + ) + + +class SpeechConfig(proto.Message): + r"""The speech generation config. + + Attributes: + voice_config (google.ai.generativelanguage_v1alpha.types.VoiceConfig): + The configuration for the speaker to use. + """ + + voice_config: "VoiceConfig" = proto.Field( + proto.MESSAGE, + number=1, + message="VoiceConfig", + ) + + +class GenerationConfig(proto.Message): + r"""Configuration options for model generation and outputs. Not + all parameters are configurable for every model. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + candidate_count (int): + Optional. Number of generated responses to + return. + Currently, this value can only be set to 1. If + unset, this will default to 1. + + This field is a member of `oneof`_ ``_candidate_count``. + stop_sequences (MutableSequence[str]): + Optional. The set of character sequences (up to 5) that will + stop output generation. If specified, the API will stop at + the first appearance of a ``stop_sequence``. The stop + sequence will not be included as part of the response. + max_output_tokens (int): + Optional. The maximum number of tokens to include in a + response candidate. + + Note: The default value varies by model, see the + ``Model.output_token_limit`` attribute of the ``Model`` + returned from the ``getModel`` function. + + This field is a member of `oneof`_ ``_max_output_tokens``. + temperature (float): + Optional. Controls the randomness of the output. + + Note: The default value varies by model, see the + ``Model.temperature`` attribute of the ``Model`` returned + from the ``getModel`` function. + + Values can range from [0.0, 2.0]. + + This field is a member of `oneof`_ ``_temperature``. + top_p (float): + Optional. The maximum cumulative probability of tokens to + consider when sampling. + + The model uses combined Top-k and Top-p (nucleus) sampling. + + Tokens are sorted based on their assigned probabilities so + that only the most likely tokens are considered. Top-k + sampling directly limits the maximum number of tokens to + consider, while Nucleus sampling limits the number of tokens + based on the cumulative probability. + + Note: The default value varies by ``Model`` and is specified + by the\ ``Model.top_p`` attribute returned from the + ``getModel`` function. An empty ``top_k`` attribute + indicates that the model doesn't apply top-k sampling and + doesn't allow setting ``top_k`` on requests. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + Gemini models use Top-p (nucleus) sampling or a combination + of Top-k and nucleus sampling. Top-k sampling considers the + set of ``top_k`` most probable tokens. Models running with + nucleus sampling don't allow top_k setting. + + Note: The default value varies by ``Model`` and is specified + by the\ ``Model.top_p`` attribute returned from the + ``getModel`` function. An empty ``top_k`` attribute + indicates that the model doesn't apply top-k sampling and + doesn't allow setting ``top_k`` on requests. + + This field is a member of `oneof`_ ``_top_k``. + response_mime_type (str): + Optional. MIME type of the generated candidate text. + Supported MIME types are: ``text/plain``: (default) Text + output. ``application/json``: JSON response in the response + candidates. ``text/x.enum``: ENUM as a string response in + the response candidates. Refer to the + `docs `__ + for a list of all supported text MIME types. + response_schema (google.ai.generativelanguage_v1alpha.types.Schema): + Optional. Output schema of the generated candidate text. + Schemas must be a subset of the `OpenAPI + schema `__ and + can be objects, primitives or arrays. + + If set, a compatible ``response_mime_type`` must also be + set. Compatible MIME types: ``application/json``: Schema for + JSON response. Refer to the `JSON text generation + guide `__ + for more details. + presence_penalty (float): + Optional. Presence penalty applied to the next token's + logprobs if the token has already been seen in the response. + + This penalty is binary on/off and not dependant on the + number of times the token is used (after the first). Use + [frequency_penalty][google.ai.generativelanguage.v1alpha.GenerationConfig.frequency_penalty] + for a penalty that increases with each use. + + A positive penalty will discourage the use of tokens that + have already been used in the response, increasing the + vocabulary. + + A negative penalty will encourage the use of tokens that + have already been used in the response, decreasing the + vocabulary. + + This field is a member of `oneof`_ ``_presence_penalty``. + frequency_penalty (float): + Optional. Frequency penalty applied to the next token's + logprobs, multiplied by the number of times each token has + been seen in the respponse so far. + + A positive penalty will discourage the use of tokens that + have already been used, proportional to the number of times + the token has been used: The more a token is used, the more + dificult it is for the model to use that token again + increasing the vocabulary of responses. + + Caution: A *negative* penalty will encourage the model to + reuse tokens proportional to the number of times the token + has been used. Small negative values will reduce the + vocabulary of a response. Larger negative values will cause + the model to start repeating a common token until it hits + the + [max_output_tokens][google.ai.generativelanguage.v1alpha.GenerationConfig.max_output_tokens] + limit. + + This field is a member of `oneof`_ ``_frequency_penalty``. + response_logprobs (bool): + Optional. If true, export the logprobs + results in response. + + This field is a member of `oneof`_ ``_response_logprobs``. + logprobs (int): + Optional. Only valid if + [response_logprobs=True][google.ai.generativelanguage.v1alpha.GenerationConfig.response_logprobs]. + This sets the number of top logprobs to return at each + decoding step in the + [Candidate.logprobs_result][google.ai.generativelanguage.v1alpha.Candidate.logprobs_result]. + + This field is a member of `oneof`_ ``_logprobs``. + enable_enhanced_civic_answers (bool): + Optional. Enables enhanced civic answers. It + may not be available for all models. + + This field is a member of `oneof`_ ``_enable_enhanced_civic_answers``. + response_modalities (MutableSequence[google.ai.generativelanguage_v1alpha.types.GenerationConfig.Modality]): + Optional. The requested modalities of the + response. Represents the set of modalities that + the model can return, and should be expected in + the response. This is an exact match to the + modalities of the response. + + A model may have multiple combinations of + supported modalities. If the requested + modalities do not match any of the supported + combinations, an error will be returned. + + An empty list is equivalent to requesting only + text. + speech_config (google.ai.generativelanguage_v1alpha.types.SpeechConfig): + Optional. The speech generation config. + + This field is a member of `oneof`_ ``_speech_config``. + """ + + class Modality(proto.Enum): + r"""Supported modalities of the response. + + Values: + MODALITY_UNSPECIFIED (0): + Default value. + TEXT (1): + Indicates the model should return text. + IMAGE (2): + Indicates the model should return images. + AUDIO (3): + Indicates the model should return audio. + """ + MODALITY_UNSPECIFIED = 0 + TEXT = 1 + IMAGE = 2 + AUDIO = 3 + + candidate_count: int = proto.Field( + proto.INT32, + number=1, + optional=True, + ) + stop_sequences: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=2, + ) + max_output_tokens: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + temperature: float = proto.Field( + proto.FLOAT, + number=5, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=6, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=7, + optional=True, + ) + response_mime_type: str = proto.Field( + proto.STRING, + number=13, + ) + response_schema: gag_content.Schema = proto.Field( + proto.MESSAGE, + number=14, + message=gag_content.Schema, + ) + presence_penalty: float = proto.Field( + proto.FLOAT, + number=15, + optional=True, + ) + frequency_penalty: float = proto.Field( + proto.FLOAT, + number=16, + optional=True, + ) + response_logprobs: bool = proto.Field( + proto.BOOL, + number=17, + optional=True, + ) + logprobs: int = proto.Field( + proto.INT32, + number=18, + optional=True, + ) + enable_enhanced_civic_answers: bool = proto.Field( + proto.BOOL, + number=19, + optional=True, + ) + response_modalities: MutableSequence[Modality] = proto.RepeatedField( + proto.ENUM, + number=20, + enum=Modality, + ) + speech_config: "SpeechConfig" = proto.Field( + proto.MESSAGE, + number=21, + optional=True, + message="SpeechConfig", + ) + + +class SemanticRetrieverConfig(proto.Message): + r"""Configuration for retrieving grounding content from a ``Corpus`` or + ``Document`` created using the Semantic Retriever API. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + source (str): + Required. Name of the resource for retrieval. Example: + ``corpora/123`` or ``corpora/123/documents/abc``. + query (google.ai.generativelanguage_v1alpha.types.Content): + Required. Query to use for matching ``Chunk``\ s in the + given resource by similarity. + metadata_filters (MutableSequence[google.ai.generativelanguage_v1alpha.types.MetadataFilter]): + Optional. Filters for selecting ``Document``\ s and/or + ``Chunk``\ s from the resource. + max_chunks_count (int): + Optional. Maximum number of relevant ``Chunk``\ s to + retrieve. + + This field is a member of `oneof`_ ``_max_chunks_count``. + minimum_relevance_score (float): + Optional. Minimum relevance score for retrieved relevant + ``Chunk``\ s. + + This field is a member of `oneof`_ ``_minimum_relevance_score``. + """ + + source: str = proto.Field( + proto.STRING, + number=1, + ) + query: gag_content.Content = proto.Field( + proto.MESSAGE, + number=2, + message=gag_content.Content, + ) + metadata_filters: MutableSequence[retriever.MetadataFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=retriever.MetadataFilter, + ) + max_chunks_count: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + minimum_relevance_score: float = proto.Field( + proto.FLOAT, + number=5, + optional=True, + ) + + +class GenerateContentResponse(proto.Message): + r"""Response from the model supporting multiple candidate responses. + + Safety ratings and content filtering are reported for both prompt in + ``GenerateContentResponse.prompt_feedback`` and for each candidate + in ``finish_reason`` and in ``safety_ratings``. The API: + + - Returns either all requested candidates or none of them + - Returns no candidates at all only if there was something wrong + with the prompt (check ``prompt_feedback``) + - Reports feedback on each candidate in ``finish_reason`` and + ``safety_ratings``. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1alpha.types.Candidate]): + Candidate responses from the model. + prompt_feedback (google.ai.generativelanguage_v1alpha.types.GenerateContentResponse.PromptFeedback): + Returns the prompt's feedback related to the + content filters. + usage_metadata (google.ai.generativelanguage_v1alpha.types.GenerateContentResponse.UsageMetadata): + Output only. Metadata on the generation + requests' token usage. + model_version (str): + Output only. The model version used to + generate the response. + """ + + class PromptFeedback(proto.Message): + r"""A set of the feedback metadata the prompt specified in + ``GenerateContentRequest.content``. + + Attributes: + block_reason (google.ai.generativelanguage_v1alpha.types.GenerateContentResponse.PromptFeedback.BlockReason): + Optional. If set, the prompt was blocked and + no candidates are returned. Rephrase the prompt. + safety_ratings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetyRating]): + Ratings for safety of the prompt. + There is at most one rating per category. + """ + + class BlockReason(proto.Enum): + r"""Specifies the reason why the prompt was blocked. + + Values: + BLOCK_REASON_UNSPECIFIED (0): + Default value. This value is unused. + SAFETY (1): + Prompt was blocked due to safety reasons. Inspect + ``safety_ratings`` to understand which safety category + blocked it. + OTHER (2): + Prompt was blocked due to unknown reasons. + BLOCKLIST (3): + Prompt was blocked due to the terms which are + included from the terminology blocklist. + PROHIBITED_CONTENT (4): + Prompt was blocked due to prohibited content. + IMAGE_SAFETY (5): + Candidates blocked due to unsafe image + generation content. + """ + BLOCK_REASON_UNSPECIFIED = 0 + SAFETY = 1 + OTHER = 2 + BLOCKLIST = 3 + PROHIBITED_CONTENT = 4 + IMAGE_SAFETY = 5 + + block_reason: "GenerateContentResponse.PromptFeedback.BlockReason" = ( + proto.Field( + proto.ENUM, + number=1, + enum="GenerateContentResponse.PromptFeedback.BlockReason", + ) + ) + safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=safety.SafetyRating, + ) + + class UsageMetadata(proto.Message): + r"""Metadata on the generation request's token usage. + + Attributes: + prompt_token_count (int): + Number of tokens in the prompt. When ``cached_content`` is + set, this is still the total effective prompt size meaning + this includes the number of tokens in the cached content. + cached_content_token_count (int): + Number of tokens in the cached part of the + prompt (the cached content) + candidates_token_count (int): + Total number of tokens across all the + generated response candidates. + total_token_count (int): + Total token count for the generation request + (prompt + response candidates). + """ + + prompt_token_count: int = proto.Field( + proto.INT32, + number=1, + ) + cached_content_token_count: int = proto.Field( + proto.INT32, + number=4, + ) + candidates_token_count: int = proto.Field( + proto.INT32, + number=2, + ) + total_token_count: int = proto.Field( + proto.INT32, + number=3, + ) + + candidates: MutableSequence["Candidate"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="Candidate", + ) + prompt_feedback: PromptFeedback = proto.Field( + proto.MESSAGE, + number=2, + message=PromptFeedback, + ) + usage_metadata: UsageMetadata = proto.Field( + proto.MESSAGE, + number=3, + message=UsageMetadata, + ) + model_version: str = proto.Field( + proto.STRING, + number=4, + ) + + +class Candidate(proto.Message): + r"""A response candidate generated from the model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + index (int): + Output only. Index of the candidate in the + list of response candidates. + + This field is a member of `oneof`_ ``_index``. + content (google.ai.generativelanguage_v1alpha.types.Content): + Output only. Generated content returned from + the model. + finish_reason (google.ai.generativelanguage_v1alpha.types.Candidate.FinishReason): + Optional. Output only. The reason why the + model stopped generating tokens. + If empty, the model has not stopped generating + tokens. + safety_ratings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetyRating]): + List of ratings for the safety of a response + candidate. + There is at most one rating per category. + citation_metadata (google.ai.generativelanguage_v1alpha.types.CitationMetadata): + Output only. Citation information for model-generated + candidate. + + This field may be populated with recitation information for + any text included in the ``content``. These are passages + that are "recited" from copyrighted material in the + foundational LLM's training data. + token_count (int): + Output only. Token count for this candidate. + grounding_attributions (MutableSequence[google.ai.generativelanguage_v1alpha.types.GroundingAttribution]): + Output only. Attribution information for sources that + contributed to a grounded answer. + + This field is populated for ``GenerateAnswer`` calls. + grounding_metadata (google.ai.generativelanguage_v1alpha.types.GroundingMetadata): + Output only. Grounding metadata for the candidate. + + This field is populated for ``GenerateContent`` calls. + avg_logprobs (float): + Output only. Average log probability score of + the candidate. + logprobs_result (google.ai.generativelanguage_v1alpha.types.LogprobsResult): + Output only. Log-likelihood scores for the + response tokens and top tokens + """ + + class FinishReason(proto.Enum): + r"""Defines the reason why the model stopped generating tokens. + + Values: + FINISH_REASON_UNSPECIFIED (0): + Default value. This value is unused. + STOP (1): + Natural stop point of the model or provided + stop sequence. + MAX_TOKENS (2): + The maximum number of tokens as specified in + the request was reached. + SAFETY (3): + The response candidate content was flagged + for safety reasons. + RECITATION (4): + The response candidate content was flagged + for recitation reasons. + LANGUAGE (6): + The response candidate content was flagged + for using an unsupported language. + OTHER (5): + Unknown reason. + BLOCKLIST (7): + Token generation stopped because the content + contains forbidden terms. + PROHIBITED_CONTENT (8): + Token generation stopped for potentially + containing prohibited content. + SPII (9): + Token generation stopped because the content + potentially contains Sensitive Personally + Identifiable Information (SPII). + MALFORMED_FUNCTION_CALL (10): + The function call generated by the model is + invalid. + IMAGE_SAFETY (11): + Token generation stopped because generated + images contain safety violations. + """ + FINISH_REASON_UNSPECIFIED = 0 + STOP = 1 + MAX_TOKENS = 2 + SAFETY = 3 + RECITATION = 4 + LANGUAGE = 6 + OTHER = 5 + BLOCKLIST = 7 + PROHIBITED_CONTENT = 8 + SPII = 9 + MALFORMED_FUNCTION_CALL = 10 + IMAGE_SAFETY = 11 + + index: int = proto.Field( + proto.INT32, + number=3, + optional=True, + ) + content: gag_content.Content = proto.Field( + proto.MESSAGE, + number=1, + message=gag_content.Content, + ) + finish_reason: FinishReason = proto.Field( + proto.ENUM, + number=2, + enum=FinishReason, + ) + safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField( + proto.MESSAGE, + number=5, + message=safety.SafetyRating, + ) + citation_metadata: citation.CitationMetadata = proto.Field( + proto.MESSAGE, + number=6, + message=citation.CitationMetadata, + ) + token_count: int = proto.Field( + proto.INT32, + number=7, + ) + grounding_attributions: MutableSequence[ + "GroundingAttribution" + ] = proto.RepeatedField( + proto.MESSAGE, + number=8, + message="GroundingAttribution", + ) + grounding_metadata: "GroundingMetadata" = proto.Field( + proto.MESSAGE, + number=9, + message="GroundingMetadata", + ) + avg_logprobs: float = proto.Field( + proto.DOUBLE, + number=10, + ) + logprobs_result: "LogprobsResult" = proto.Field( + proto.MESSAGE, + number=11, + message="LogprobsResult", + ) + + +class LogprobsResult(proto.Message): + r"""Logprobs Result + + Attributes: + top_candidates (MutableSequence[google.ai.generativelanguage_v1alpha.types.LogprobsResult.TopCandidates]): + Length = total number of decoding steps. + chosen_candidates (MutableSequence[google.ai.generativelanguage_v1alpha.types.LogprobsResult.Candidate]): + Length = total number of decoding steps. The chosen + candidates may or may not be in top_candidates. + """ + + class Candidate(proto.Message): + r"""Candidate for the logprobs token and score. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + token (str): + The candidate’s token string value. + + This field is a member of `oneof`_ ``_token``. + token_id (int): + The candidate’s token id value. + + This field is a member of `oneof`_ ``_token_id``. + log_probability (float): + The candidate's log probability. + + This field is a member of `oneof`_ ``_log_probability``. + """ + + token: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + token_id: int = proto.Field( + proto.INT32, + number=3, + optional=True, + ) + log_probability: float = proto.Field( + proto.FLOAT, + number=2, + optional=True, + ) + + class TopCandidates(proto.Message): + r"""Candidates with top log probabilities at each decoding step. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1alpha.types.LogprobsResult.Candidate]): + Sorted by log probability in descending + order. + """ + + candidates: MutableSequence["LogprobsResult.Candidate"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="LogprobsResult.Candidate", + ) + + top_candidates: MutableSequence[TopCandidates] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=TopCandidates, + ) + chosen_candidates: MutableSequence[Candidate] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=Candidate, + ) + + +class AttributionSourceId(proto.Message): + r"""Identifier for the source contributing to this attribution. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + grounding_passage (google.ai.generativelanguage_v1alpha.types.AttributionSourceId.GroundingPassageId): + Identifier for an inline passage. + + This field is a member of `oneof`_ ``source``. + semantic_retriever_chunk (google.ai.generativelanguage_v1alpha.types.AttributionSourceId.SemanticRetrieverChunk): + Identifier for a ``Chunk`` fetched via Semantic Retriever. + + This field is a member of `oneof`_ ``source``. + """ + + class GroundingPassageId(proto.Message): + r"""Identifier for a part within a ``GroundingPassage``. + + Attributes: + passage_id (str): + Output only. ID of the passage matching the + ``GenerateAnswerRequest``'s ``GroundingPassage.id``. + part_index (int): + Output only. Index of the part within the + ``GenerateAnswerRequest``'s ``GroundingPassage.content``. + """ + + passage_id: str = proto.Field( + proto.STRING, + number=1, + ) + part_index: int = proto.Field( + proto.INT32, + number=2, + ) + + class SemanticRetrieverChunk(proto.Message): + r"""Identifier for a ``Chunk`` retrieved via Semantic Retriever + specified in the ``GenerateAnswerRequest`` using + ``SemanticRetrieverConfig``. + + Attributes: + source (str): + Output only. Name of the source matching the request's + ``SemanticRetrieverConfig.source``. Example: ``corpora/123`` + or ``corpora/123/documents/abc`` + chunk (str): + Output only. Name of the ``Chunk`` containing the attributed + text. Example: ``corpora/123/documents/abc/chunks/xyz`` + """ + + source: str = proto.Field( + proto.STRING, + number=1, + ) + chunk: str = proto.Field( + proto.STRING, + number=2, + ) + + grounding_passage: GroundingPassageId = proto.Field( + proto.MESSAGE, + number=1, + oneof="source", + message=GroundingPassageId, + ) + semantic_retriever_chunk: SemanticRetrieverChunk = proto.Field( + proto.MESSAGE, + number=2, + oneof="source", + message=SemanticRetrieverChunk, + ) + + +class GroundingAttribution(proto.Message): + r"""Attribution for a source that contributed to an answer. + + Attributes: + source_id (google.ai.generativelanguage_v1alpha.types.AttributionSourceId): + Output only. Identifier for the source + contributing to this attribution. + content (google.ai.generativelanguage_v1alpha.types.Content): + Grounding source content that makes up this + attribution. + """ + + source_id: "AttributionSourceId" = proto.Field( + proto.MESSAGE, + number=3, + message="AttributionSourceId", + ) + content: gag_content.Content = proto.Field( + proto.MESSAGE, + number=2, + message=gag_content.Content, + ) + + +class RetrievalMetadata(proto.Message): + r"""Metadata related to retrieval in the grounding flow. + + Attributes: + google_search_dynamic_retrieval_score (float): + Optional. Score indicating how likely information from + google search could help answer the prompt. The score is in + the range [0, 1], where 0 is the least likely and 1 is the + most likely. This score is only populated when google search + grounding and dynamic retrieval is enabled. It will be + compared to the threshold to determine whether to trigger + google search. + """ + + google_search_dynamic_retrieval_score: float = proto.Field( + proto.FLOAT, + number=2, + ) + + +class GroundingMetadata(proto.Message): + r"""Metadata returned to client when grounding is enabled. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + search_entry_point (google.ai.generativelanguage_v1alpha.types.SearchEntryPoint): + Optional. Google search entry for the + following-up web searches. + + This field is a member of `oneof`_ ``_search_entry_point``. + grounding_chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.GroundingChunk]): + List of supporting references retrieved from + specified grounding source. + grounding_supports (MutableSequence[google.ai.generativelanguage_v1alpha.types.GroundingSupport]): + List of grounding support. + retrieval_metadata (google.ai.generativelanguage_v1alpha.types.RetrievalMetadata): + Metadata related to retrieval in the + grounding flow. + + This field is a member of `oneof`_ ``_retrieval_metadata``. + web_search_queries (MutableSequence[str]): + Web search queries for the following-up web + search. + """ + + search_entry_point: "SearchEntryPoint" = proto.Field( + proto.MESSAGE, + number=1, + optional=True, + message="SearchEntryPoint", + ) + grounding_chunks: MutableSequence["GroundingChunk"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="GroundingChunk", + ) + grounding_supports: MutableSequence["GroundingSupport"] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="GroundingSupport", + ) + retrieval_metadata: "RetrievalMetadata" = proto.Field( + proto.MESSAGE, + number=4, + optional=True, + message="RetrievalMetadata", + ) + web_search_queries: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=5, + ) + + +class SearchEntryPoint(proto.Message): + r"""Google search entry point. + + Attributes: + rendered_content (str): + Optional. Web content snippet that can be + embedded in a web page or an app webview. + sdk_blob (bytes): + Optional. Base64 encoded JSON representing + array of tuple. + """ + + rendered_content: str = proto.Field( + proto.STRING, + number=1, + ) + sdk_blob: bytes = proto.Field( + proto.BYTES, + number=2, + ) + + +class GroundingChunk(proto.Message): + r"""Grounding chunk. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + web (google.ai.generativelanguage_v1alpha.types.GroundingChunk.Web): + Grounding chunk from the web. + + This field is a member of `oneof`_ ``chunk_type``. + """ + + class Web(proto.Message): + r"""Chunk from the web. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + uri (str): + URI reference of the chunk. + + This field is a member of `oneof`_ ``_uri``. + title (str): + Title of the chunk. + + This field is a member of `oneof`_ ``_title``. + """ + + uri: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + title: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + web: Web = proto.Field( + proto.MESSAGE, + number=1, + oneof="chunk_type", + message=Web, + ) + + +class Segment(proto.Message): + r"""Segment of the content. + + Attributes: + part_index (int): + Output only. The index of a Part object + within its parent Content object. + start_index (int): + Output only. Start index in the given Part, + measured in bytes. Offset from the start of the + Part, inclusive, starting at zero. + end_index (int): + Output only. End index in the given Part, + measured in bytes. Offset from the start of the + Part, exclusive, starting at zero. + text (str): + Output only. The text corresponding to the + segment from the response. + """ + + part_index: int = proto.Field( + proto.INT32, + number=1, + ) + start_index: int = proto.Field( + proto.INT32, + number=2, + ) + end_index: int = proto.Field( + proto.INT32, + number=3, + ) + text: str = proto.Field( + proto.STRING, + number=4, + ) + + +class GroundingSupport(proto.Message): + r"""Grounding support. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + segment (google.ai.generativelanguage_v1alpha.types.Segment): + Segment of the content this support belongs + to. + + This field is a member of `oneof`_ ``_segment``. + grounding_chunk_indices (MutableSequence[int]): + A list of indices (into 'grounding_chunk') specifying the + citations associated with the claim. For instance [1,3,4] + means that grounding_chunk[1], grounding_chunk[3], + grounding_chunk[4] are the retrieved content attributed to + the claim. + confidence_scores (MutableSequence[float]): + Confidence score of the support references. Ranges from 0 to + 1. 1 is the most confident. This list must have the same + size as the grounding_chunk_indices. + """ + + segment: "Segment" = proto.Field( + proto.MESSAGE, + number=1, + optional=True, + message="Segment", + ) + grounding_chunk_indices: MutableSequence[int] = proto.RepeatedField( + proto.INT32, + number=2, + ) + confidence_scores: MutableSequence[float] = proto.RepeatedField( + proto.FLOAT, + number=3, + ) + + +class GenerateAnswerRequest(proto.Message): + r"""Request to generate a grounded answer from the ``Model``. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + inline_passages (google.ai.generativelanguage_v1alpha.types.GroundingPassages): + Passages provided inline with the request. + + This field is a member of `oneof`_ ``grounding_source``. + semantic_retriever (google.ai.generativelanguage_v1alpha.types.SemanticRetrieverConfig): + Content retrieved from resources created via + the Semantic Retriever API. + + This field is a member of `oneof`_ ``grounding_source``. + model (str): + Required. The name of the ``Model`` to use for generating + the grounded response. + + Format: ``model=models/{model}``. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Required. The content of the current conversation with the + ``Model``. For single-turn queries, this is a single + question to answer. For multi-turn queries, this is a + repeated field that contains conversation history and the + last ``Content`` in the list containing the question. + + Note: ``GenerateAnswer`` only supports queries in English. + answer_style (google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest.AnswerStyle): + Required. Style in which answers should be + returned. + safety_settings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]): + Optional. A list of unique ``SafetySetting`` instances for + blocking unsafe content. + + This will be enforced on the + ``GenerateAnswerRequest.contents`` and + ``GenerateAnswerResponse.candidate``. There should not be + more than one setting for each ``SafetyCategory`` type. The + API will block any contents and responses that fail to meet + the thresholds set by these settings. This list overrides + the default settings for each ``SafetyCategory`` specified + in the safety_settings. If there is no ``SafetySetting`` for + a given ``SafetyCategory`` provided in the list, the API + will use the default safety setting for that category. Harm + categories HARM_CATEGORY_HATE_SPEECH, + HARM_CATEGORY_SEXUALLY_EXPLICIT, + HARM_CATEGORY_DANGEROUS_CONTENT, HARM_CATEGORY_HARASSMENT + are supported. Refer to the + `guide `__ + for detailed information on available safety settings. Also + refer to the `Safety + guidance `__ + to learn how to incorporate safety considerations in your AI + applications. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range from [0.0,1.0], inclusive. A value closer + to 1.0 will produce responses that are more varied and + creative, while a value closer to 0.0 will typically result + in more straightforward responses from the model. A low + temperature (~0.2) is usually recommended for + Attributed-Question-Answering use cases. + + This field is a member of `oneof`_ ``_temperature``. + """ + + class AnswerStyle(proto.Enum): + r"""Style for grounded answers. + + Values: + ANSWER_STYLE_UNSPECIFIED (0): + Unspecified answer style. + ABSTRACTIVE (1): + Succint but abstract style. + EXTRACTIVE (2): + Very brief and extractive style. + VERBOSE (3): + Verbose style including extra details. The + response may be formatted as a sentence, + paragraph, multiple paragraphs, or bullet + points, etc. + """ + ANSWER_STYLE_UNSPECIFIED = 0 + ABSTRACTIVE = 1 + EXTRACTIVE = 2 + VERBOSE = 3 + + inline_passages: gag_content.GroundingPassages = proto.Field( + proto.MESSAGE, + number=6, + oneof="grounding_source", + message=gag_content.GroundingPassages, + ) + semantic_retriever: "SemanticRetrieverConfig" = proto.Field( + proto.MESSAGE, + number=7, + oneof="grounding_source", + message="SemanticRetrieverConfig", + ) + model: str = proto.Field( + proto.STRING, + number=1, + ) + contents: MutableSequence[gag_content.Content] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gag_content.Content, + ) + answer_style: AnswerStyle = proto.Field( + proto.ENUM, + number=5, + enum=AnswerStyle, + ) + safety_settings: MutableSequence[safety.SafetySetting] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.SafetySetting, + ) + temperature: float = proto.Field( + proto.FLOAT, + number=4, + optional=True, + ) + + +class GenerateAnswerResponse(proto.Message): + r"""Response from the model for a grounded answer. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + answer (google.ai.generativelanguage_v1alpha.types.Candidate): + Candidate answer from the model. + + Note: The model *always* attempts to provide a grounded + answer, even when the answer is unlikely to be answerable + from the given passages. In that case, a low-quality or + ungrounded answer may be provided, along with a low + ``answerable_probability``. + answerable_probability (float): + Output only. The model's estimate of the probability that + its answer is correct and grounded in the input passages. + + A low ``answerable_probability`` indicates that the answer + might not be grounded in the sources. + + When ``answerable_probability`` is low, you may want to: + + - Display a message to the effect of "We couldn’t answer + that question" to the user. + - Fall back to a general-purpose LLM that answers the + question from world knowledge. The threshold and nature + of such fallbacks will depend on individual use cases. + ``0.5`` is a good starting threshold. + + This field is a member of `oneof`_ ``_answerable_probability``. + input_feedback (google.ai.generativelanguage_v1alpha.types.GenerateAnswerResponse.InputFeedback): + Output only. Feedback related to the input data used to + answer the question, as opposed to the model-generated + response to the question. + + The input data can be one or more of the following: + + - Question specified by the last entry in + ``GenerateAnswerRequest.content`` + - Conversation history specified by the other entries in + ``GenerateAnswerRequest.content`` + - Grounding sources + (``GenerateAnswerRequest.semantic_retriever`` or + ``GenerateAnswerRequest.inline_passages``) + + This field is a member of `oneof`_ ``_input_feedback``. + """ + + class InputFeedback(proto.Message): + r"""Feedback related to the input data used to answer the + question, as opposed to the model-generated response to the + question. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + block_reason (google.ai.generativelanguage_v1alpha.types.GenerateAnswerResponse.InputFeedback.BlockReason): + Optional. If set, the input was blocked and + no candidates are returned. Rephrase the input. + + This field is a member of `oneof`_ ``_block_reason``. + safety_ratings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetyRating]): + Ratings for safety of the input. + There is at most one rating per category. + """ + + class BlockReason(proto.Enum): + r"""Specifies what was the reason why input was blocked. + + Values: + BLOCK_REASON_UNSPECIFIED (0): + Default value. This value is unused. + SAFETY (1): + Input was blocked due to safety reasons. Inspect + ``safety_ratings`` to understand which safety category + blocked it. + OTHER (2): + Input was blocked due to other reasons. + """ + BLOCK_REASON_UNSPECIFIED = 0 + SAFETY = 1 + OTHER = 2 + + block_reason: "GenerateAnswerResponse.InputFeedback.BlockReason" = proto.Field( + proto.ENUM, + number=1, + optional=True, + enum="GenerateAnswerResponse.InputFeedback.BlockReason", + ) + safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=safety.SafetyRating, + ) + + answer: "Candidate" = proto.Field( + proto.MESSAGE, + number=1, + message="Candidate", + ) + answerable_probability: float = proto.Field( + proto.FLOAT, + number=2, + optional=True, + ) + input_feedback: InputFeedback = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=InputFeedback, + ) + + +class EmbedContentRequest(proto.Message): + r"""Request containing the ``Content`` for the model to embed. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + content (google.ai.generativelanguage_v1alpha.types.Content): + Required. The content to embed. Only the ``parts.text`` + fields will be counted. + task_type (google.ai.generativelanguage_v1alpha.types.TaskType): + Optional. Optional task type for which the embeddings will + be used. Can only be set for ``models/embedding-001``. + + This field is a member of `oneof`_ ``_task_type``. + title (str): + Optional. An optional title for the text. Only applicable + when TaskType is ``RETRIEVAL_DOCUMENT``. + + Note: Specifying a ``title`` for ``RETRIEVAL_DOCUMENT`` + provides better quality embeddings for retrieval. + + This field is a member of `oneof`_ ``_title``. + output_dimensionality (int): + Optional. Optional reduced dimension for the output + embedding. If set, excessive values in the output embedding + are truncated from the end. Supported by newer models since + 2024 only. You cannot set this value if using the earlier + model (``models/embedding-001``). + + This field is a member of `oneof`_ ``_output_dimensionality``. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + content: gag_content.Content = proto.Field( + proto.MESSAGE, + number=2, + message=gag_content.Content, + ) + task_type: "TaskType" = proto.Field( + proto.ENUM, + number=3, + optional=True, + enum="TaskType", + ) + title: str = proto.Field( + proto.STRING, + number=4, + optional=True, + ) + output_dimensionality: int = proto.Field( + proto.INT32, + number=5, + optional=True, + ) + + +class ContentEmbedding(proto.Message): + r"""A list of floats representing an embedding. + + Attributes: + values (MutableSequence[float]): + The embedding values. + """ + + values: MutableSequence[float] = proto.RepeatedField( + proto.FLOAT, + number=1, + ) + + +class EmbedContentResponse(proto.Message): + r"""The response to an ``EmbedContentRequest``. + + Attributes: + embedding (google.ai.generativelanguage_v1alpha.types.ContentEmbedding): + Output only. The embedding generated from the + input content. + """ + + embedding: "ContentEmbedding" = proto.Field( + proto.MESSAGE, + number=1, + message="ContentEmbedding", + ) + + +class BatchEmbedContentsRequest(proto.Message): + r"""Batch request to get embeddings from the model for a list of + prompts. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + requests (MutableSequence[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest]): + Required. Embed requests for the batch. The model in each of + these requests must match the model specified + ``BatchEmbedContentsRequest.model``. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + requests: MutableSequence["EmbedContentRequest"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="EmbedContentRequest", + ) + + +class BatchEmbedContentsResponse(proto.Message): + r"""The response to a ``BatchEmbedContentsRequest``. + + Attributes: + embeddings (MutableSequence[google.ai.generativelanguage_v1alpha.types.ContentEmbedding]): + Output only. The embeddings for each request, + in the same order as provided in the batch + request. + """ + + embeddings: MutableSequence["ContentEmbedding"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="ContentEmbedding", + ) + + +class CountTokensRequest(proto.Message): + r"""Counts the number of tokens in the ``prompt`` sent to a model. + + Models may tokenize text differently, so each model may return a + different ``token_count``. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Optional. The input given to the model as a prompt. This + field is ignored when ``generate_content_request`` is set. + generate_content_request (google.ai.generativelanguage_v1alpha.types.GenerateContentRequest): + Optional. The overall input given to the ``Model``. This + includes the prompt as well as other model steering + information like `system + instructions `__, + and/or function declarations for `function + calling `__. + ``Model``\ s/\ ``Content``\ s and + ``generate_content_request``\ s are mutually exclusive. You + can either send ``Model`` + ``Content``\ s or a + ``generate_content_request``, but never both. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + contents: MutableSequence[gag_content.Content] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gag_content.Content, + ) + generate_content_request: "GenerateContentRequest" = proto.Field( + proto.MESSAGE, + number=3, + message="GenerateContentRequest", + ) + + +class CountTokensResponse(proto.Message): + r"""A response from ``CountTokens``. + + It returns the model's ``token_count`` for the ``prompt``. + + Attributes: + total_tokens (int): + The number of tokens that the ``Model`` tokenizes the + ``prompt`` into. Always non-negative. + cached_content_token_count (int): + Number of tokens in the cached part of the + prompt (the cached content). + """ + + total_tokens: int = proto.Field( + proto.INT32, + number=1, + ) + cached_content_token_count: int = proto.Field( + proto.INT32, + number=5, + ) + + +class BidiGenerateContentSetup(proto.Message): + r"""Message to be sent in the first and only first + ``BidiGenerateContentClientMessage``. Contains configuration that + will apply for the duration of the streaming RPC. + + Clients should wait for a ``BidiGenerateContentSetupComplete`` + message before sending any additional messages. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + Format: ``models/{model}`` + generation_config (google.ai.generativelanguage_v1alpha.types.GenerationConfig): + Optional. Generation config. + + The following fields are not supported: + + - ``response_logprobs`` + - ``response_mime_type`` + - ``logprobs`` + - ``response_schema`` + - ``stop_sequence`` + - ``routing_config`` + - ``audio_timestamp`` + system_instruction (google.ai.generativelanguage_v1alpha.types.Content): + Optional. The user provided system + instructions for the model. + Note: Only text should be used in parts and + content in each part will be in a separate + paragraph. + tools (MutableSequence[google.ai.generativelanguage_v1alpha.types.Tool]): + Optional. A list of ``Tools`` the model may use to generate + the next response. + + A ``Tool`` is a piece of code that enables the system to + interact with external systems to perform an action, or set + of actions, outside of knowledge and scope of the model. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + generation_config: "GenerationConfig" = proto.Field( + proto.MESSAGE, + number=2, + message="GenerationConfig", + ) + system_instruction: gag_content.Content = proto.Field( + proto.MESSAGE, + number=3, + message=gag_content.Content, + ) + tools: MutableSequence[gag_content.Tool] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=gag_content.Tool, + ) + + +class BidiGenerateContentClientContent(proto.Message): + r"""Incremental update of the current conversation delivered from + the client. All of the content here is unconditionally appended + to the conversation history and used as part of the prompt to + the model to generate content. + + A message here will interrupt any current model generation. + + Attributes: + turns (MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]): + Optional. The content appended to the current + conversation with the model. + For single-turn queries, this is a single + instance. For multi-turn queries, this is a + repeated field that contains conversation + history and the latest request. + turn_complete (bool): + Optional. If true, indicates that the server + content generation should start with the + currently accumulated prompt. Otherwise, the + server awaits additional messages before + starting generation. + """ + + turns: MutableSequence[gag_content.Content] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_content.Content, + ) + turn_complete: bool = proto.Field( + proto.BOOL, + number=2, + ) + + +class BidiGenerateContentRealtimeInput(proto.Message): + r"""User input that is sent in real time. + + This is different from + [BidiGenerateContentClientContent][google.ai.generativelanguage.v1alpha.BidiGenerateContentClientContent] + in a few ways: + + - Can be sent continuously without interruption to model + generation. + - If there is a need to mix data interleaved across the + [BidiGenerateContentClientContent][google.ai.generativelanguage.v1alpha.BidiGenerateContentClientContent] + and the + [BidiGenerateContentRealtimeInput][google.ai.generativelanguage.v1alpha.BidiGenerateContentRealtimeInput], + the server attempts to optimize for best response, but there are + no guarantees. + - End of turn is not explicitly specified, but is rather derived + from user activity (for example, end of speech). + - Even before the end of turn, the data is processed incrementally + to optimize for a fast start of the response from the model. + - Is always direct user input that is sent in real time. Can be + sent continuously without interruptions. The model automatically + detects the beginning and the end of user speech and starts or + terminates streaming the response accordingly. Data is processed + incrementally as it arrives, minimizing latency. + + Attributes: + media_chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.Blob]): + Optional. Inlined bytes data for media input. + """ + + media_chunks: MutableSequence[gag_content.Blob] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_content.Blob, + ) + + +class BidiGenerateContentToolResponse(proto.Message): + r"""Client generated response to a ``ToolCall`` received from the + server. Individual ``FunctionResponse`` objects are matched to the + respective ``FunctionCall`` objects by the ``id`` field. + + Note that in the unary and server-streaming GenerateContent APIs + function calling happens by exchanging the ``Content`` parts, while + in the bidi GenerateContent APIs function calling happens over these + dedicated set of messages. + + Attributes: + function_responses (MutableSequence[google.ai.generativelanguage_v1alpha.types.FunctionResponse]): + Optional. The response to the function calls. + """ + + function_responses: MutableSequence[ + gag_content.FunctionResponse + ] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_content.FunctionResponse, + ) + + +class BidiGenerateContentClientMessage(proto.Message): + r"""Messages sent by the client in the BidiGenerateContent call. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + setup (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentSetup): + Optional. Session configuration sent in the + first and only first client message. + + This field is a member of `oneof`_ ``message_type``. + client_content (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentClientContent): + Optional. Incremental update of the current + conversation delivered from the client. + + This field is a member of `oneof`_ ``message_type``. + realtime_input (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentRealtimeInput): + Optional. User input that is sent in real + time. + + This field is a member of `oneof`_ ``message_type``. + tool_response (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentToolResponse): + Optional. Response to a ``ToolCallMessage`` received from + the server. + + This field is a member of `oneof`_ ``message_type``. + """ + + setup: "BidiGenerateContentSetup" = proto.Field( + proto.MESSAGE, + number=1, + oneof="message_type", + message="BidiGenerateContentSetup", + ) + client_content: "BidiGenerateContentClientContent" = proto.Field( + proto.MESSAGE, + number=2, + oneof="message_type", + message="BidiGenerateContentClientContent", + ) + realtime_input: "BidiGenerateContentRealtimeInput" = proto.Field( + proto.MESSAGE, + number=3, + oneof="message_type", + message="BidiGenerateContentRealtimeInput", + ) + tool_response: "BidiGenerateContentToolResponse" = proto.Field( + proto.MESSAGE, + number=4, + oneof="message_type", + message="BidiGenerateContentToolResponse", + ) + + +class BidiGenerateContentSetupComplete(proto.Message): + r"""Sent in response to a ``BidiGenerateContentSetup`` message from the + client. + + """ + + +class BidiGenerateContentServerContent(proto.Message): + r"""Incremental server update generated by the model in response + to client messages. + + Content is generated as quickly as possible, and not in real + time. Clients may choose to buffer and play it out in real time. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model_turn (google.ai.generativelanguage_v1alpha.types.Content): + Output only. The content that the model has + generated as part of the current conversation + with the user. + + This field is a member of `oneof`_ ``_model_turn``. + turn_complete (bool): + Output only. If true, indicates that the model is done + generating. Generation will only start in response to + additional client messages. Can be set alongside + ``content``, indicating that the ``content`` is the last in + the turn. + interrupted (bool): + Output only. If true, indicates that a client + message has interrupted current model + generation. If the client is playing out the + content in real time, this is a good signal to + stop and empty the current playback queue. + grounding_metadata (google.ai.generativelanguage_v1alpha.types.GroundingMetadata): + Output only. Grounding metadata for the + generated content. + """ + + model_turn: gag_content.Content = proto.Field( + proto.MESSAGE, + number=1, + optional=True, + message=gag_content.Content, + ) + turn_complete: bool = proto.Field( + proto.BOOL, + number=2, + ) + interrupted: bool = proto.Field( + proto.BOOL, + number=3, + ) + grounding_metadata: "GroundingMetadata" = proto.Field( + proto.MESSAGE, + number=4, + message="GroundingMetadata", + ) + + +class BidiGenerateContentToolCall(proto.Message): + r"""Request for the client to execute the ``function_calls`` and return + the responses with the matching ``id``\ s. + + Attributes: + function_calls (MutableSequence[google.ai.generativelanguage_v1alpha.types.FunctionCall]): + Output only. The function call to be + executed. + """ + + function_calls: MutableSequence[gag_content.FunctionCall] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=gag_content.FunctionCall, + ) + + +class BidiGenerateContentToolCallCancellation(proto.Message): + r"""Notification for the client that a previously issued + ``ToolCallMessage`` with the specified ``id``\ s should have been + not executed and should be cancelled. If there were side-effects to + those tool calls, clients may attempt to undo the tool calls. This + message occurs only in cases where the clients interrupt server + turns. + + Attributes: + ids (MutableSequence[str]): + Output only. The ids of the tool calls to be + cancelled. + """ + + ids: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=1, + ) + + +class BidiGenerateContentServerMessage(proto.Message): + r"""Response message for the BidiGenerateContent call. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + setup_complete (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentSetupComplete): + Output only. Sent in response to a + ``BidiGenerateContentSetup`` message from the client when + setup is complete. + + This field is a member of `oneof`_ ``message_type``. + server_content (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentServerContent): + Output only. Content generated by the model + in response to client messages. + + This field is a member of `oneof`_ ``message_type``. + tool_call (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentToolCall): + Output only. Request for the client to execute the + ``function_calls`` and return the responses with the + matching ``id``\ s. + + This field is a member of `oneof`_ ``message_type``. + tool_call_cancellation (google.ai.generativelanguage_v1alpha.types.BidiGenerateContentToolCallCancellation): + Output only. Notification for the client that a previously + issued ``ToolCallMessage`` with the specified ``id``\ s + should be cancelled. + + This field is a member of `oneof`_ ``message_type``. + """ + + setup_complete: "BidiGenerateContentSetupComplete" = proto.Field( + proto.MESSAGE, + number=2, + oneof="message_type", + message="BidiGenerateContentSetupComplete", + ) + server_content: "BidiGenerateContentServerContent" = proto.Field( + proto.MESSAGE, + number=3, + oneof="message_type", + message="BidiGenerateContentServerContent", + ) + tool_call: "BidiGenerateContentToolCall" = proto.Field( + proto.MESSAGE, + number=4, + oneof="message_type", + message="BidiGenerateContentToolCall", + ) + tool_call_cancellation: "BidiGenerateContentToolCallCancellation" = proto.Field( + proto.MESSAGE, + number=5, + oneof="message_type", + message="BidiGenerateContentToolCallCancellation", + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model.py new file mode 100644 index 000000000000..9549cefb8fb6 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "Model", + }, +) + + +class Model(proto.Message): + r"""Information about a Generative Language Model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + name (str): + Required. The resource name of the ``Model``. Refer to + `Model + variants `__ + for all allowed values. + + Format: ``models/{model}`` with a ``{model}`` naming + convention of: + + - "{base_model_id}-{version}" + + Examples: + + - ``models/gemini-1.5-flash-001`` + base_model_id (str): + Required. The name of the base model, pass this to the + generation request. + + Examples: + + - ``gemini-1.5-flash`` + version (str): + Required. The version number of the model. + + This represents the major version (``1.0`` or ``1.5``) + display_name (str): + The human-readable name of the model. E.g. + "Gemini 1.5 Flash". + The name can be up to 128 characters long and + can consist of any UTF-8 characters. + description (str): + A short description of the model. + input_token_limit (int): + Maximum number of input tokens allowed for + this model. + output_token_limit (int): + Maximum number of output tokens available for + this model. + supported_generation_methods (MutableSequence[str]): + The model's supported generation methods. + + The corresponding API method names are defined as Pascal + case strings, such as ``generateMessage`` and + ``generateContent``. + temperature (float): + Controls the randomness of the output. + + Values can range over ``[0.0,max_temperature]``, inclusive. + A higher value will produce responses that are more varied, + while a value closer to ``0.0`` will typically result in + less surprising responses from the model. This value + specifies default to be used by the backend while making the + call to the model. + + This field is a member of `oneof`_ ``_temperature``. + max_temperature (float): + The maximum temperature this model can use. + + This field is a member of `oneof`_ ``_max_temperature``. + top_p (float): + For `Nucleus + sampling `__. + + Nucleus sampling considers the smallest set of tokens whose + probability sum is at least ``top_p``. This value specifies + default to be used by the backend while making the call to + the model. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + For Top-k sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. This value specifies default to be used by the + backend while making the call to the model. If empty, + indicates the model doesn't use top-k sampling, and + ``top_k`` isn't allowed as a generation parameter. + + This field is a member of `oneof`_ ``_top_k``. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + base_model_id: str = proto.Field( + proto.STRING, + number=2, + ) + version: str = proto.Field( + proto.STRING, + number=3, + ) + display_name: str = proto.Field( + proto.STRING, + number=4, + ) + description: str = proto.Field( + proto.STRING, + number=5, + ) + input_token_limit: int = proto.Field( + proto.INT32, + number=6, + ) + output_token_limit: int = proto.Field( + proto.INT32, + number=7, + ) + supported_generation_methods: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=8, + ) + temperature: float = proto.Field( + proto.FLOAT, + number=9, + optional=True, + ) + max_temperature: float = proto.Field( + proto.FLOAT, + number=13, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=10, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=11, + optional=True, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model_service.py new file mode 100644 index 000000000000..afa4fc0bc748 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/model_service.py @@ -0,0 +1,332 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import field_mask_pb2 # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "GetModelRequest", + "ListModelsRequest", + "ListModelsResponse", + "GetTunedModelRequest", + "ListTunedModelsRequest", + "ListTunedModelsResponse", + "CreateTunedModelRequest", + "CreateTunedModelMetadata", + "UpdateTunedModelRequest", + "DeleteTunedModelRequest", + }, +) + + +class GetModelRequest(proto.Message): + r"""Request for getting information about a specific Model. + + Attributes: + name (str): + Required. The resource name of the model. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ListModelsRequest(proto.Message): + r"""Request for listing all Models. + + Attributes: + page_size (int): + The maximum number of ``Models`` to return (per page). + + If unspecified, 50 models will be returned per page. This + method returns at most 1000 models per page, even if you + pass a larger page_size. + page_token (str): + A page token, received from a previous ``ListModels`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListModels`` must match the call that provided the page + token. + """ + + page_size: int = proto.Field( + proto.INT32, + number=2, + ) + page_token: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListModelsResponse(proto.Message): + r"""Response from ``ListModel`` containing a paginated list of Models. + + Attributes: + models (MutableSequence[google.ai.generativelanguage_v1alpha.types.Model]): + The returned Models. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. + + If this field is omitted, there are no more pages. + """ + + @property + def raw_page(self): + return self + + models: MutableSequence[model.Model] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=model.Model, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class GetTunedModelRequest(proto.Message): + r"""Request for getting information about a specific Model. + + Attributes: + name (str): + Required. The resource name of the model. + + Format: ``tunedModels/my-model-id`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ListTunedModelsRequest(proto.Message): + r"""Request for listing TunedModels. + + Attributes: + page_size (int): + Optional. The maximum number of ``TunedModels`` to return + (per page). The service may return fewer tuned models. + + If unspecified, at most 10 tuned models will be returned. + This method returns at most 1000 models per page, even if + you pass a larger page_size. + page_token (str): + Optional. A page token, received from a previous + ``ListTunedModels`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListTunedModels`` must match the call that provided the + page token. + filter (str): + Optional. A filter is a full text search over + the tuned model's description and display name. + By default, results will not include tuned + models shared with everyone. + + Additional operators: + + - owner:me + - writers:me + - readers:me + - readers:everyone + + Examples: + + "owner:me" returns all tuned models to which + caller has owner role "readers:me" returns all + tuned models to which caller has reader role + "readers:everyone" returns all tuned models that + are shared with everyone + """ + + page_size: int = proto.Field( + proto.INT32, + number=1, + ) + page_token: str = proto.Field( + proto.STRING, + number=2, + ) + filter: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListTunedModelsResponse(proto.Message): + r"""Response from ``ListTunedModels`` containing a paginated list of + Models. + + Attributes: + tuned_models (MutableSequence[google.ai.generativelanguage_v1alpha.types.TunedModel]): + The returned Models. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. + + If this field is omitted, there are no more pages. + """ + + @property + def raw_page(self): + return self + + tuned_models: MutableSequence[gag_tuned_model.TunedModel] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_tuned_model.TunedModel, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class CreateTunedModelRequest(proto.Message): + r"""Request to create a TunedModel. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + tuned_model_id (str): + Optional. The unique id for the tuned model if specified. + This value should be up to 40 characters, the first + character must be a letter, the last could be a letter or a + number. The id must match the regular expression: + ``[a-z]([a-z0-9-]{0,38}[a-z0-9])?``. + + This field is a member of `oneof`_ ``_tuned_model_id``. + tuned_model (google.ai.generativelanguage_v1alpha.types.TunedModel): + Required. The tuned model to create. + """ + + tuned_model_id: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + tuned_model: gag_tuned_model.TunedModel = proto.Field( + proto.MESSAGE, + number=2, + message=gag_tuned_model.TunedModel, + ) + + +class CreateTunedModelMetadata(proto.Message): + r"""Metadata about the state and progress of creating a tuned + model returned from the long-running operation + + Attributes: + tuned_model (str): + Name of the tuned model associated with the + tuning operation. + total_steps (int): + The total number of tuning steps. + completed_steps (int): + The number of steps completed. + completed_percent (float): + The completed percentage for the tuning + operation. + snapshots (MutableSequence[google.ai.generativelanguage_v1alpha.types.TuningSnapshot]): + Metrics collected during tuning. + """ + + tuned_model: str = proto.Field( + proto.STRING, + number=5, + ) + total_steps: int = proto.Field( + proto.INT32, + number=1, + ) + completed_steps: int = proto.Field( + proto.INT32, + number=2, + ) + completed_percent: float = proto.Field( + proto.FLOAT, + number=3, + ) + snapshots: MutableSequence[gag_tuned_model.TuningSnapshot] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=gag_tuned_model.TuningSnapshot, + ) + + +class UpdateTunedModelRequest(proto.Message): + r"""Request to update a TunedModel. + + Attributes: + tuned_model (google.ai.generativelanguage_v1alpha.types.TunedModel): + Required. The tuned model to update. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Optional. The list of fields to update. + """ + + tuned_model: gag_tuned_model.TunedModel = proto.Field( + proto.MESSAGE, + number=1, + message=gag_tuned_model.TunedModel, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class DeleteTunedModelRequest(proto.Message): + r"""Request to delete a TunedModel. + + Attributes: + name (str): + Required. The resource name of the model. Format: + ``tunedModels/my-model-id`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission.py new file mode 100644 index 000000000000..a73758c6c688 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission.py @@ -0,0 +1,141 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "Permission", + }, +) + + +class Permission(proto.Message): + r"""Permission resource grants user, group or the rest of the + world access to the PaLM API resource (e.g. a tuned model, + corpus). + + A role is a collection of permitted operations that allows users + to perform specific actions on PaLM API resources. To make them + available to users, groups, or service accounts, you assign + roles. When you assign a role, you grant permissions that the + role contains. + + There are three concentric roles. Each role is a superset of the + previous role's permitted operations: + + - reader can use the resource (e.g. tuned model, corpus) for + inference + - writer has reader's permissions and additionally can edit and + share + - owner has writer's permissions and additionally can delete + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + name (str): + Output only. Identifier. The permission name. A unique name + will be generated on create. Examples: + tunedModels/{tuned_model}/permissions/{permission} + corpora/{corpus}/permissions/{permission} Output only. + grantee_type (google.ai.generativelanguage_v1alpha.types.Permission.GranteeType): + Optional. Immutable. The type of the grantee. + + This field is a member of `oneof`_ ``_grantee_type``. + email_address (str): + Optional. Immutable. The email address of the + user of group which this permission refers. + Field is not set when permission's grantee type + is EVERYONE. + + This field is a member of `oneof`_ ``_email_address``. + role (google.ai.generativelanguage_v1alpha.types.Permission.Role): + Required. The role granted by this + permission. + + This field is a member of `oneof`_ ``_role``. + """ + + class GranteeType(proto.Enum): + r"""Defines types of the grantee of this permission. + + Values: + GRANTEE_TYPE_UNSPECIFIED (0): + The default value. This value is unused. + USER (1): + Represents a user. When set, you must provide email_address + for the user. + GROUP (2): + Represents a group. When set, you must provide email_address + for the group. + EVERYONE (3): + Represents access to everyone. No extra + information is required. + """ + GRANTEE_TYPE_UNSPECIFIED = 0 + USER = 1 + GROUP = 2 + EVERYONE = 3 + + class Role(proto.Enum): + r"""Defines the role granted by this permission. + + Values: + ROLE_UNSPECIFIED (0): + The default value. This value is unused. + OWNER (1): + Owner can use, update, share and delete the + resource. + WRITER (2): + Writer can use, update and share the + resource. + READER (3): + Reader can use the resource. + """ + ROLE_UNSPECIFIED = 0 + OWNER = 1 + WRITER = 2 + READER = 3 + + name: str = proto.Field( + proto.STRING, + number=1, + ) + grantee_type: GranteeType = proto.Field( + proto.ENUM, + number=2, + optional=True, + enum=GranteeType, + ) + email_address: str = proto.Field( + proto.STRING, + number=3, + optional=True, + ) + role: Role = proto.Field( + proto.ENUM, + number=4, + optional=True, + enum=Role, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission_service.py new file mode 100644 index 000000000000..1e64faf3fe0d --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/permission_service.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import field_mask_pb2 # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "CreatePermissionRequest", + "GetPermissionRequest", + "ListPermissionsRequest", + "ListPermissionsResponse", + "UpdatePermissionRequest", + "DeletePermissionRequest", + "TransferOwnershipRequest", + "TransferOwnershipResponse", + }, +) + + +class CreatePermissionRequest(proto.Message): + r"""Request to create a ``Permission``. + + Attributes: + parent (str): + Required. The parent resource of the ``Permission``. + Formats: ``tunedModels/{tuned_model}`` ``corpora/{corpus}`` + permission (google.ai.generativelanguage_v1alpha.types.Permission): + Required. The permission to create. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + permission: gag_permission.Permission = proto.Field( + proto.MESSAGE, + number=2, + message=gag_permission.Permission, + ) + + +class GetPermissionRequest(proto.Message): + r"""Request for getting information about a specific ``Permission``. + + Attributes: + name (str): + Required. The resource name of the permission. + + Formats: + ``tunedModels/{tuned_model}/permissions/{permission}`` + ``corpora/{corpus}/permissions/{permission}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class ListPermissionsRequest(proto.Message): + r"""Request for listing permissions. + + Attributes: + parent (str): + Required. The parent resource of the permissions. Formats: + ``tunedModels/{tuned_model}`` ``corpora/{corpus}`` + page_size (int): + Optional. The maximum number of ``Permission``\ s to return + (per page). The service may return fewer permissions. + + If unspecified, at most 10 permissions will be returned. + This method returns at most 1000 permissions per page, even + if you pass larger page_size. + page_token (str): + Optional. A page token, received from a previous + ``ListPermissions`` call. + + Provide the ``page_token`` returned by one request as an + argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListPermissions`` must match the call that provided the + page token. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + page_size: int = proto.Field( + proto.INT32, + number=2, + ) + page_token: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListPermissionsResponse(proto.Message): + r"""Response from ``ListPermissions`` containing a paginated list of + permissions. + + Attributes: + permissions (MutableSequence[google.ai.generativelanguage_v1alpha.types.Permission]): + Returned permissions. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. + + If this field is omitted, there are no more pages. + """ + + @property + def raw_page(self): + return self + + permissions: MutableSequence[gag_permission.Permission] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=gag_permission.Permission, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class UpdatePermissionRequest(proto.Message): + r"""Request to update the ``Permission``. + + Attributes: + permission (google.ai.generativelanguage_v1alpha.types.Permission): + Required. The permission to update. + + The permission's ``name`` field is used to identify the + permission to update. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Accepted ones: + + - role (``Permission.role`` field) + """ + + permission: gag_permission.Permission = proto.Field( + proto.MESSAGE, + number=1, + message=gag_permission.Permission, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class DeletePermissionRequest(proto.Message): + r"""Request to delete the ``Permission``. + + Attributes: + name (str): + Required. The resource name of the permission. Formats: + ``tunedModels/{tuned_model}/permissions/{permission}`` + ``corpora/{corpus}/permissions/{permission}`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class TransferOwnershipRequest(proto.Message): + r"""Request to transfer the ownership of the tuned model. + + Attributes: + name (str): + Required. The resource name of the tuned model to transfer + ownership. + + Format: ``tunedModels/my-model-id`` + email_address (str): + Required. The email address of the user to + whom the tuned model is being transferred to. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + email_address: str = proto.Field( + proto.STRING, + number=2, + ) + + +class TransferOwnershipResponse(proto.Message): + r"""Response from ``TransferOwnership``.""" + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/prediction_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/prediction_service.py new file mode 100644 index 000000000000..bb5326c4a9db --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/prediction_service.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import struct_pb2 # type: ignore +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "PredictRequest", + "PredictResponse", + }, +) + + +class PredictRequest(proto.Message): + r"""Request message for + [PredictionService.Predict][google.ai.generativelanguage.v1alpha.PredictionService.Predict]. + + Attributes: + model (str): + Required. The name of the model for prediction. Format: + ``name=models/{model}``. + instances (MutableSequence[google.protobuf.struct_pb2.Value]): + Required. The instances that are the input to + the prediction call. + parameters (google.protobuf.struct_pb2.Value): + Optional. The parameters that govern the + prediction call. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + instances: MutableSequence[struct_pb2.Value] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=struct_pb2.Value, + ) + parameters: struct_pb2.Value = proto.Field( + proto.MESSAGE, + number=3, + message=struct_pb2.Value, + ) + + +class PredictResponse(proto.Message): + r"""Response message for [PredictionService.Predict]. + + Attributes: + predictions (MutableSequence[google.protobuf.struct_pb2.Value]): + The outputs of the prediction call. + """ + + predictions: MutableSequence[struct_pb2.Value] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=struct_pb2.Value, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever.py new file mode 100644 index 000000000000..53b3a1bfb845 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever.py @@ -0,0 +1,411 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import timestamp_pb2 # type: ignore +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "Corpus", + "Document", + "StringList", + "CustomMetadata", + "MetadataFilter", + "Condition", + "Chunk", + "ChunkData", + }, +) + + +class Corpus(proto.Message): + r"""A ``Corpus`` is a collection of ``Document``\ s. A project can + create up to 5 corpora. + + Attributes: + name (str): + Immutable. Identifier. The ``Corpus`` resource name. The ID + (name excluding the "corpora/" prefix) can contain up to 40 + characters that are lowercase alphanumeric or dashes (-). + The ID cannot start or end with a dash. If the name is empty + on create, a unique name will be derived from + ``display_name`` along with a 12 character random suffix. + Example: ``corpora/my-awesome-corpora-123a456b789c`` + display_name (str): + Optional. The human-readable display name for the + ``Corpus``. The display name must be no more than 512 + characters in length, including spaces. Example: "Docs on + Semantic Retriever". + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The Timestamp of when the ``Corpus`` was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The Timestamp of when the ``Corpus`` was last + updated. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=2, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=3, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=4, + message=timestamp_pb2.Timestamp, + ) + + +class Document(proto.Message): + r"""A ``Document`` is a collection of ``Chunk``\ s. A ``Corpus`` can + have a maximum of 10,000 ``Document``\ s. + + Attributes: + name (str): + Immutable. Identifier. The ``Document`` resource name. The + ID (name excluding the "corpora/*/documents/" prefix) can + contain up to 40 characters that are lowercase alphanumeric + or dashes (-). The ID cannot start or end with a dash. If + the name is empty on create, a unique name will be derived + from ``display_name`` along with a 12 character random + suffix. Example: + ``corpora/{corpus_id}/documents/my-awesome-doc-123a456b789c`` + display_name (str): + Optional. The human-readable display name for the + ``Document``. The display name must be no more than 512 + characters in length, including spaces. Example: "Semantic + Retriever Documentation". + custom_metadata (MutableSequence[google.ai.generativelanguage_v1alpha.types.CustomMetadata]): + Optional. User provided custom metadata stored as key-value + pairs used for querying. A ``Document`` can have a maximum + of 20 ``CustomMetadata``. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The Timestamp of when the ``Document`` was last + updated. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The Timestamp of when the ``Document`` was + created. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=2, + ) + custom_metadata: MutableSequence["CustomMetadata"] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="CustomMetadata", + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=4, + message=timestamp_pb2.Timestamp, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=5, + message=timestamp_pb2.Timestamp, + ) + + +class StringList(proto.Message): + r"""User provided string values assigned to a single metadata + key. + + Attributes: + values (MutableSequence[str]): + The string values of the metadata to store. + """ + + values: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=1, + ) + + +class CustomMetadata(proto.Message): + r"""User provided metadata stored as key-value pairs. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + string_value (str): + The string value of the metadata to store. + + This field is a member of `oneof`_ ``value``. + string_list_value (google.ai.generativelanguage_v1alpha.types.StringList): + The StringList value of the metadata to + store. + + This field is a member of `oneof`_ ``value``. + numeric_value (float): + The numeric value of the metadata to store. + + This field is a member of `oneof`_ ``value``. + key (str): + Required. The key of the metadata to store. + """ + + string_value: str = proto.Field( + proto.STRING, + number=2, + oneof="value", + ) + string_list_value: "StringList" = proto.Field( + proto.MESSAGE, + number=6, + oneof="value", + message="StringList", + ) + numeric_value: float = proto.Field( + proto.FLOAT, + number=7, + oneof="value", + ) + key: str = proto.Field( + proto.STRING, + number=1, + ) + + +class MetadataFilter(proto.Message): + r"""User provided filter to limit retrieval based on ``Chunk`` or + ``Document`` level metadata values. Example (genre = drama OR genre + = action): key = "document.custom_metadata.genre" conditions = + [{string_value = "drama", operation = EQUAL}, {string_value = + "action", operation = EQUAL}] + + Attributes: + key (str): + Required. The key of the metadata to filter + on. + conditions (MutableSequence[google.ai.generativelanguage_v1alpha.types.Condition]): + Required. The ``Condition``\ s for the given key that will + trigger this filter. Multiple ``Condition``\ s are joined by + logical ORs. + """ + + key: str = proto.Field( + proto.STRING, + number=1, + ) + conditions: MutableSequence["Condition"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="Condition", + ) + + +class Condition(proto.Message): + r"""Filter condition applicable to a single key. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + string_value (str): + The string value to filter the metadata on. + + This field is a member of `oneof`_ ``value``. + numeric_value (float): + The numeric value to filter the metadata on. + + This field is a member of `oneof`_ ``value``. + operation (google.ai.generativelanguage_v1alpha.types.Condition.Operator): + Required. Operator applied to the given + key-value pair to trigger the condition. + """ + + class Operator(proto.Enum): + r"""Defines the valid operators that can be applied to a + key-value pair. + + Values: + OPERATOR_UNSPECIFIED (0): + The default value. This value is unused. + LESS (1): + Supported by numeric. + LESS_EQUAL (2): + Supported by numeric. + EQUAL (3): + Supported by numeric & string. + GREATER_EQUAL (4): + Supported by numeric. + GREATER (5): + Supported by numeric. + NOT_EQUAL (6): + Supported by numeric & string. + INCLUDES (7): + Supported by string only when ``CustomMetadata`` value type + for the given key has a ``string_list_value``. + EXCLUDES (8): + Supported by string only when ``CustomMetadata`` value type + for the given key has a ``string_list_value``. + """ + OPERATOR_UNSPECIFIED = 0 + LESS = 1 + LESS_EQUAL = 2 + EQUAL = 3 + GREATER_EQUAL = 4 + GREATER = 5 + NOT_EQUAL = 6 + INCLUDES = 7 + EXCLUDES = 8 + + string_value: str = proto.Field( + proto.STRING, + number=1, + oneof="value", + ) + numeric_value: float = proto.Field( + proto.FLOAT, + number=6, + oneof="value", + ) + operation: Operator = proto.Field( + proto.ENUM, + number=5, + enum=Operator, + ) + + +class Chunk(proto.Message): + r"""A ``Chunk`` is a subpart of a ``Document`` that is treated as an + independent unit for the purposes of vector representation and + storage. A ``Corpus`` can have a maximum of 1 million ``Chunk``\ s. + + Attributes: + name (str): + Immutable. Identifier. The ``Chunk`` resource name. The ID + (name excluding the "corpora/*/documents/*/chunks/" prefix) + can contain up to 40 characters that are lowercase + alphanumeric or dashes (-). The ID cannot start or end with + a dash. If the name is empty on create, a random + 12-character unique ID will be generated. Example: + ``corpora/{corpus_id}/documents/{document_id}/chunks/123a456b789c`` + data (google.ai.generativelanguage_v1alpha.types.ChunkData): + Required. The content for the ``Chunk``, such as the text + string. The maximum number of tokens per chunk is 2043. + custom_metadata (MutableSequence[google.ai.generativelanguage_v1alpha.types.CustomMetadata]): + Optional. User provided custom metadata stored as key-value + pairs. The maximum number of ``CustomMetadata`` per chunk is + 20. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The Timestamp of when the ``Chunk`` was + created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The Timestamp of when the ``Chunk`` was last + updated. + state (google.ai.generativelanguage_v1alpha.types.Chunk.State): + Output only. Current state of the ``Chunk``. + """ + + class State(proto.Enum): + r"""States for the lifecycle of a ``Chunk``. + + Values: + STATE_UNSPECIFIED (0): + The default value. This value is used if the + state is omitted. + STATE_PENDING_PROCESSING (1): + ``Chunk`` is being processed (embedding and vector storage). + STATE_ACTIVE (2): + ``Chunk`` is processed and available for querying. + STATE_FAILED (10): + ``Chunk`` failed processing. + """ + STATE_UNSPECIFIED = 0 + STATE_PENDING_PROCESSING = 1 + STATE_ACTIVE = 2 + STATE_FAILED = 10 + + name: str = proto.Field( + proto.STRING, + number=1, + ) + data: "ChunkData" = proto.Field( + proto.MESSAGE, + number=2, + message="ChunkData", + ) + custom_metadata: MutableSequence["CustomMetadata"] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="CustomMetadata", + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=4, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=5, + message=timestamp_pb2.Timestamp, + ) + state: State = proto.Field( + proto.ENUM, + number=6, + enum=State, + ) + + +class ChunkData(proto.Message): + r"""Extracted data that represents the ``Chunk`` content. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + string_value (str): + The ``Chunk`` content as a string. The maximum number of + tokens per chunk is 2043. + + This field is a member of `oneof`_ ``data``. + """ + + string_value: str = proto.Field( + proto.STRING, + number=1, + oneof="data", + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever_service.py new file mode 100644 index 000000000000..0aa460d1e3b3 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/retriever_service.py @@ -0,0 +1,793 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import field_mask_pb2 # type: ignore +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import retriever + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "CreateCorpusRequest", + "GetCorpusRequest", + "UpdateCorpusRequest", + "DeleteCorpusRequest", + "ListCorporaRequest", + "ListCorporaResponse", + "QueryCorpusRequest", + "QueryCorpusResponse", + "RelevantChunk", + "CreateDocumentRequest", + "GetDocumentRequest", + "UpdateDocumentRequest", + "DeleteDocumentRequest", + "ListDocumentsRequest", + "ListDocumentsResponse", + "QueryDocumentRequest", + "QueryDocumentResponse", + "CreateChunkRequest", + "BatchCreateChunksRequest", + "BatchCreateChunksResponse", + "GetChunkRequest", + "UpdateChunkRequest", + "BatchUpdateChunksRequest", + "BatchUpdateChunksResponse", + "DeleteChunkRequest", + "BatchDeleteChunksRequest", + "ListChunksRequest", + "ListChunksResponse", + }, +) + + +class CreateCorpusRequest(proto.Message): + r"""Request to create a ``Corpus``. + + Attributes: + corpus (google.ai.generativelanguage_v1alpha.types.Corpus): + Required. The ``Corpus`` to create. + """ + + corpus: retriever.Corpus = proto.Field( + proto.MESSAGE, + number=1, + message=retriever.Corpus, + ) + + +class GetCorpusRequest(proto.Message): + r"""Request for getting information about a specific ``Corpus``. + + Attributes: + name (str): + Required. The name of the ``Corpus``. Example: + ``corpora/my-corpus-123`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class UpdateCorpusRequest(proto.Message): + r"""Request to update a ``Corpus``. + + Attributes: + corpus (google.ai.generativelanguage_v1alpha.types.Corpus): + Required. The ``Corpus`` to update. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, this only + supports updating ``display_name``. + """ + + corpus: retriever.Corpus = proto.Field( + proto.MESSAGE, + number=1, + message=retriever.Corpus, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class DeleteCorpusRequest(proto.Message): + r"""Request to delete a ``Corpus``. + + Attributes: + name (str): + Required. The resource name of the ``Corpus``. Example: + ``corpora/my-corpus-123`` + force (bool): + Optional. If set to true, any ``Document``\ s and objects + related to this ``Corpus`` will also be deleted. + + If false (the default), a ``FAILED_PRECONDITION`` error will + be returned if ``Corpus`` contains any ``Document``\ s. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + force: bool = proto.Field( + proto.BOOL, + number=2, + ) + + +class ListCorporaRequest(proto.Message): + r"""Request for listing ``Corpora``. + + Attributes: + page_size (int): + Optional. The maximum number of ``Corpora`` to return (per + page). The service may return fewer ``Corpora``. + + If unspecified, at most 10 ``Corpora`` will be returned. The + maximum size limit is 20 ``Corpora`` per page. + page_token (str): + Optional. A page token, received from a previous + ``ListCorpora`` call. + + Provide the ``next_page_token`` returned in the response as + an argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListCorpora`` must match the call that provided the page + token. + """ + + page_size: int = proto.Field( + proto.INT32, + number=1, + ) + page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class ListCorporaResponse(proto.Message): + r"""Response from ``ListCorpora`` containing a paginated list of + ``Corpora``. The results are sorted by ascending + ``corpus.create_time``. + + Attributes: + corpora (MutableSequence[google.ai.generativelanguage_v1alpha.types.Corpus]): + The returned corpora. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. If this field is omitted, there are no more + pages. + """ + + @property + def raw_page(self): + return self + + corpora: MutableSequence[retriever.Corpus] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=retriever.Corpus, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class QueryCorpusRequest(proto.Message): + r"""Request for querying a ``Corpus``. + + Attributes: + name (str): + Required. The name of the ``Corpus`` to query. Example: + ``corpora/my-corpus-123`` + query (str): + Required. Query string to perform semantic + search. + metadata_filters (MutableSequence[google.ai.generativelanguage_v1alpha.types.MetadataFilter]): + Optional. Filter for ``Chunk`` and ``Document`` metadata. + Each ``MetadataFilter`` object should correspond to a unique + key. Multiple ``MetadataFilter`` objects are joined by + logical "AND"s. + + Example query at document level: (year >= 2020 OR year < + 2010) AND (genre = drama OR genre = action) + + ``MetadataFilter`` object list: metadata_filters = [ {key = + "document.custom_metadata.year" conditions = [{int_value = + 2020, operation = GREATER_EQUAL}, {int_value = 2010, + operation = LESS}]}, {key = "document.custom_metadata.year" + conditions = [{int_value = 2020, operation = GREATER_EQUAL}, + {int_value = 2010, operation = LESS}]}, {key = + "document.custom_metadata.genre" conditions = [{string_value + = "drama", operation = EQUAL}, {string_value = "action", + operation = EQUAL}]}] + + Example query at chunk level for a numeric range of values: + (year > 2015 AND year <= 2020) + + ``MetadataFilter`` object list: metadata_filters = [ {key = + "chunk.custom_metadata.year" conditions = [{int_value = + 2015, operation = GREATER}]}, {key = + "chunk.custom_metadata.year" conditions = [{int_value = + 2020, operation = LESS_EQUAL}]}] + + Note: "AND"s for the same key are only supported for numeric + values. String values only support "OR"s for the same key. + results_count (int): + Optional. The maximum number of ``Chunk``\ s to return. The + service may return fewer ``Chunk``\ s. + + If unspecified, at most 10 ``Chunk``\ s will be returned. + The maximum specified result count is 100. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + query: str = proto.Field( + proto.STRING, + number=2, + ) + metadata_filters: MutableSequence[retriever.MetadataFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=retriever.MetadataFilter, + ) + results_count: int = proto.Field( + proto.INT32, + number=4, + ) + + +class QueryCorpusResponse(proto.Message): + r"""Response from ``QueryCorpus`` containing a list of relevant chunks. + + Attributes: + relevant_chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.RelevantChunk]): + The relevant chunks. + """ + + relevant_chunks: MutableSequence["RelevantChunk"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="RelevantChunk", + ) + + +class RelevantChunk(proto.Message): + r"""The information for a chunk relevant to a query. + + Attributes: + chunk_relevance_score (float): + ``Chunk`` relevance to the query. + chunk (google.ai.generativelanguage_v1alpha.types.Chunk): + ``Chunk`` associated with the query. + """ + + chunk_relevance_score: float = proto.Field( + proto.FLOAT, + number=1, + ) + chunk: retriever.Chunk = proto.Field( + proto.MESSAGE, + number=2, + message=retriever.Chunk, + ) + + +class CreateDocumentRequest(proto.Message): + r"""Request to create a ``Document``. + + Attributes: + parent (str): + Required. The name of the ``Corpus`` where this ``Document`` + will be created. Example: ``corpora/my-corpus-123`` + document (google.ai.generativelanguage_v1alpha.types.Document): + Required. The ``Document`` to create. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + document: retriever.Document = proto.Field( + proto.MESSAGE, + number=2, + message=retriever.Document, + ) + + +class GetDocumentRequest(proto.Message): + r"""Request for getting information about a specific ``Document``. + + Attributes: + name (str): + Required. The name of the ``Document`` to retrieve. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class UpdateDocumentRequest(proto.Message): + r"""Request to update a ``Document``. + + Attributes: + document (google.ai.generativelanguage_v1alpha.types.Document): + Required. The ``Document`` to update. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, this only + supports updating ``display_name`` and ``custom_metadata``. + """ + + document: retriever.Document = proto.Field( + proto.MESSAGE, + number=1, + message=retriever.Document, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class DeleteDocumentRequest(proto.Message): + r"""Request to delete a ``Document``. + + Attributes: + name (str): + Required. The resource name of the ``Document`` to delete. + Example: ``corpora/my-corpus-123/documents/the-doc-abc`` + force (bool): + Optional. If set to true, any ``Chunk``\ s and objects + related to this ``Document`` will also be deleted. + + If false (the default), a ``FAILED_PRECONDITION`` error will + be returned if ``Document`` contains any ``Chunk``\ s. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + force: bool = proto.Field( + proto.BOOL, + number=2, + ) + + +class ListDocumentsRequest(proto.Message): + r"""Request for listing ``Document``\ s. + + Attributes: + parent (str): + Required. The name of the ``Corpus`` containing + ``Document``\ s. Example: ``corpora/my-corpus-123`` + page_size (int): + Optional. The maximum number of ``Document``\ s to return + (per page). The service may return fewer ``Document``\ s. + + If unspecified, at most 10 ``Document``\ s will be returned. + The maximum size limit is 20 ``Document``\ s per page. + page_token (str): + Optional. A page token, received from a previous + ``ListDocuments`` call. + + Provide the ``next_page_token`` returned in the response as + an argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListDocuments`` must match the call that provided the page + token. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + page_size: int = proto.Field( + proto.INT32, + number=2, + ) + page_token: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListDocumentsResponse(proto.Message): + r"""Response from ``ListDocuments`` containing a paginated list of + ``Document``\ s. The ``Document``\ s are sorted by ascending + ``document.create_time``. + + Attributes: + documents (MutableSequence[google.ai.generativelanguage_v1alpha.types.Document]): + The returned ``Document``\ s. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. If this field is omitted, there are no more + pages. + """ + + @property + def raw_page(self): + return self + + documents: MutableSequence[retriever.Document] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=retriever.Document, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +class QueryDocumentRequest(proto.Message): + r"""Request for querying a ``Document``. + + Attributes: + name (str): + Required. The name of the ``Document`` to query. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + query (str): + Required. Query string to perform semantic + search. + results_count (int): + Optional. The maximum number of ``Chunk``\ s to return. The + service may return fewer ``Chunk``\ s. + + If unspecified, at most 10 ``Chunk``\ s will be returned. + The maximum specified result count is 100. + metadata_filters (MutableSequence[google.ai.generativelanguage_v1alpha.types.MetadataFilter]): + Optional. Filter for ``Chunk`` metadata. Each + ``MetadataFilter`` object should correspond to a unique key. + Multiple ``MetadataFilter`` objects are joined by logical + "AND"s. + + Note: ``Document``-level filtering is not supported for this + request because a ``Document`` name is already specified. + + Example query: (year >= 2020 OR year < 2010) AND (genre = + drama OR genre = action) + + ``MetadataFilter`` object list: metadata_filters = [ {key = + "chunk.custom_metadata.year" conditions = [{int_value = + 2020, operation = GREATER_EQUAL}, {int_value = 2010, + operation = LESS}}, {key = "chunk.custom_metadata.genre" + conditions = [{string_value = "drama", operation = EQUAL}, + {string_value = "action", operation = EQUAL}}] + + Example query for a numeric range of values: (year > 2015 + AND year <= 2020) + + ``MetadataFilter`` object list: metadata_filters = [ {key = + "chunk.custom_metadata.year" conditions = [{int_value = + 2015, operation = GREATER}]}, {key = + "chunk.custom_metadata.year" conditions = [{int_value = + 2020, operation = LESS_EQUAL}]}] + + Note: "AND"s for the same key are only supported for numeric + values. String values only support "OR"s for the same key. + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + query: str = proto.Field( + proto.STRING, + number=2, + ) + results_count: int = proto.Field( + proto.INT32, + number=3, + ) + metadata_filters: MutableSequence[retriever.MetadataFilter] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=retriever.MetadataFilter, + ) + + +class QueryDocumentResponse(proto.Message): + r"""Response from ``QueryDocument`` containing a list of relevant + chunks. + + Attributes: + relevant_chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.RelevantChunk]): + The returned relevant chunks. + """ + + relevant_chunks: MutableSequence["RelevantChunk"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="RelevantChunk", + ) + + +class CreateChunkRequest(proto.Message): + r"""Request to create a ``Chunk``. + + Attributes: + parent (str): + Required. The name of the ``Document`` where this ``Chunk`` + will be created. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + chunk (google.ai.generativelanguage_v1alpha.types.Chunk): + Required. The ``Chunk`` to create. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + chunk: retriever.Chunk = proto.Field( + proto.MESSAGE, + number=2, + message=retriever.Chunk, + ) + + +class BatchCreateChunksRequest(proto.Message): + r"""Request to batch create ``Chunk``\ s. + + Attributes: + parent (str): + Optional. The name of the ``Document`` where this batch of + ``Chunk``\ s will be created. The parent field in every + ``CreateChunkRequest`` must match this value. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + requests (MutableSequence[google.ai.generativelanguage_v1alpha.types.CreateChunkRequest]): + Required. The request messages specifying the ``Chunk``\ s + to create. A maximum of 100 ``Chunk``\ s can be created in a + batch. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + requests: MutableSequence["CreateChunkRequest"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="CreateChunkRequest", + ) + + +class BatchCreateChunksResponse(proto.Message): + r"""Response from ``BatchCreateChunks`` containing a list of created + ``Chunk``\ s. + + Attributes: + chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.Chunk]): + ``Chunk``\ s created. + """ + + chunks: MutableSequence[retriever.Chunk] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=retriever.Chunk, + ) + + +class GetChunkRequest(proto.Message): + r"""Request for getting information about a specific ``Chunk``. + + Attributes: + name (str): + Required. The name of the ``Chunk`` to retrieve. Example: + ``corpora/my-corpus-123/documents/the-doc-abc/chunks/some-chunk`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class UpdateChunkRequest(proto.Message): + r"""Request to update a ``Chunk``. + + Attributes: + chunk (google.ai.generativelanguage_v1alpha.types.Chunk): + Required. The ``Chunk`` to update. + update_mask (google.protobuf.field_mask_pb2.FieldMask): + Required. The list of fields to update. Currently, this only + supports updating ``custom_metadata`` and ``data``. + """ + + chunk: retriever.Chunk = proto.Field( + proto.MESSAGE, + number=1, + message=retriever.Chunk, + ) + update_mask: field_mask_pb2.FieldMask = proto.Field( + proto.MESSAGE, + number=2, + message=field_mask_pb2.FieldMask, + ) + + +class BatchUpdateChunksRequest(proto.Message): + r"""Request to batch update ``Chunk``\ s. + + Attributes: + parent (str): + Optional. The name of the ``Document`` containing the + ``Chunk``\ s to update. The parent field in every + ``UpdateChunkRequest`` must match this value. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + requests (MutableSequence[google.ai.generativelanguage_v1alpha.types.UpdateChunkRequest]): + Required. The request messages specifying the ``Chunk``\ s + to update. A maximum of 100 ``Chunk``\ s can be updated in a + batch. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + requests: MutableSequence["UpdateChunkRequest"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="UpdateChunkRequest", + ) + + +class BatchUpdateChunksResponse(proto.Message): + r"""Response from ``BatchUpdateChunks`` containing a list of updated + ``Chunk``\ s. + + Attributes: + chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.Chunk]): + ``Chunk``\ s updated. + """ + + chunks: MutableSequence[retriever.Chunk] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=retriever.Chunk, + ) + + +class DeleteChunkRequest(proto.Message): + r"""Request to delete a ``Chunk``. + + Attributes: + name (str): + Required. The resource name of the ``Chunk`` to delete. + Example: + ``corpora/my-corpus-123/documents/the-doc-abc/chunks/some-chunk`` + """ + + name: str = proto.Field( + proto.STRING, + number=1, + ) + + +class BatchDeleteChunksRequest(proto.Message): + r"""Request to batch delete ``Chunk``\ s. + + Attributes: + parent (str): + Optional. The name of the ``Document`` containing the + ``Chunk``\ s to delete. The parent field in every + ``DeleteChunkRequest`` must match this value. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + requests (MutableSequence[google.ai.generativelanguage_v1alpha.types.DeleteChunkRequest]): + Required. The request messages specifying the ``Chunk``\ s + to delete. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + requests: MutableSequence["DeleteChunkRequest"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="DeleteChunkRequest", + ) + + +class ListChunksRequest(proto.Message): + r"""Request for listing ``Chunk``\ s. + + Attributes: + parent (str): + Required. The name of the ``Document`` containing + ``Chunk``\ s. Example: + ``corpora/my-corpus-123/documents/the-doc-abc`` + page_size (int): + Optional. The maximum number of ``Chunk``\ s to return (per + page). The service may return fewer ``Chunk``\ s. + + If unspecified, at most 10 ``Chunk``\ s will be returned. + The maximum size limit is 100 ``Chunk``\ s per page. + page_token (str): + Optional. A page token, received from a previous + ``ListChunks`` call. + + Provide the ``next_page_token`` returned in the response as + an argument to the next request to retrieve the next page. + + When paginating, all other parameters provided to + ``ListChunks`` must match the call that provided the page + token. + """ + + parent: str = proto.Field( + proto.STRING, + number=1, + ) + page_size: int = proto.Field( + proto.INT32, + number=2, + ) + page_token: str = proto.Field( + proto.STRING, + number=3, + ) + + +class ListChunksResponse(proto.Message): + r"""Response from ``ListChunks`` containing a paginated list of + ``Chunk``\ s. The ``Chunk``\ s are sorted by ascending + ``chunk.create_time``. + + Attributes: + chunks (MutableSequence[google.ai.generativelanguage_v1alpha.types.Chunk]): + The returned ``Chunk``\ s. + next_page_token (str): + A token, which can be sent as ``page_token`` to retrieve the + next page. If this field is omitted, there are no more + pages. + """ + + @property + def raw_page(self): + return self + + chunks: MutableSequence[retriever.Chunk] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message=retriever.Chunk, + ) + next_page_token: str = proto.Field( + proto.STRING, + number=2, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/safety.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/safety.py new file mode 100644 index 000000000000..3fbad41feeff --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/safety.py @@ -0,0 +1,276 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "HarmCategory", + "ContentFilter", + "SafetyFeedback", + "SafetyRating", + "SafetySetting", + }, +) + + +class HarmCategory(proto.Enum): + r"""The category of a rating. + + These categories cover various kinds of harms that developers + may wish to adjust. + + Values: + HARM_CATEGORY_UNSPECIFIED (0): + Category is unspecified. + HARM_CATEGORY_DEROGATORY (1): + **PaLM** - Negative or harmful comments targeting identity + and/or protected attribute. + HARM_CATEGORY_TOXICITY (2): + **PaLM** - Content that is rude, disrespectful, or profane. + HARM_CATEGORY_VIOLENCE (3): + **PaLM** - Describes scenarios depicting violence against an + individual or group, or general descriptions of gore. + HARM_CATEGORY_SEXUAL (4): + **PaLM** - Contains references to sexual acts or other lewd + content. + HARM_CATEGORY_MEDICAL (5): + **PaLM** - Promotes unchecked medical advice. + HARM_CATEGORY_DANGEROUS (6): + **PaLM** - Dangerous content that promotes, facilitates, or + encourages harmful acts. + HARM_CATEGORY_HARASSMENT (7): + **Gemini** - Harassment content. + HARM_CATEGORY_HATE_SPEECH (8): + **Gemini** - Hate speech and content. + HARM_CATEGORY_SEXUALLY_EXPLICIT (9): + **Gemini** - Sexually explicit content. + HARM_CATEGORY_DANGEROUS_CONTENT (10): + **Gemini** - Dangerous content. + HARM_CATEGORY_CIVIC_INTEGRITY (11): + **Gemini** - Content that may be used to harm civic + integrity. + """ + HARM_CATEGORY_UNSPECIFIED = 0 + HARM_CATEGORY_DEROGATORY = 1 + HARM_CATEGORY_TOXICITY = 2 + HARM_CATEGORY_VIOLENCE = 3 + HARM_CATEGORY_SEXUAL = 4 + HARM_CATEGORY_MEDICAL = 5 + HARM_CATEGORY_DANGEROUS = 6 + HARM_CATEGORY_HARASSMENT = 7 + HARM_CATEGORY_HATE_SPEECH = 8 + HARM_CATEGORY_SEXUALLY_EXPLICIT = 9 + HARM_CATEGORY_DANGEROUS_CONTENT = 10 + HARM_CATEGORY_CIVIC_INTEGRITY = 11 + + +class ContentFilter(proto.Message): + r"""Content filtering metadata associated with processing a + single request. + ContentFilter contains a reason and an optional supporting + string. The reason may be unspecified. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + reason (google.ai.generativelanguage_v1alpha.types.ContentFilter.BlockedReason): + The reason content was blocked during request + processing. + message (str): + A string that describes the filtering + behavior in more detail. + + This field is a member of `oneof`_ ``_message``. + """ + + class BlockedReason(proto.Enum): + r"""A list of reasons why content may have been blocked. + + Values: + BLOCKED_REASON_UNSPECIFIED (0): + A blocked reason was not specified. + SAFETY (1): + Content was blocked by safety settings. + OTHER (2): + Content was blocked, but the reason is + uncategorized. + """ + BLOCKED_REASON_UNSPECIFIED = 0 + SAFETY = 1 + OTHER = 2 + + reason: BlockedReason = proto.Field( + proto.ENUM, + number=1, + enum=BlockedReason, + ) + message: str = proto.Field( + proto.STRING, + number=2, + optional=True, + ) + + +class SafetyFeedback(proto.Message): + r"""Safety feedback for an entire request. + + This field is populated if content in the input and/or response + is blocked due to safety settings. SafetyFeedback may not exist + for every HarmCategory. Each SafetyFeedback will return the + safety settings used by the request as well as the lowest + HarmProbability that should be allowed in order to return a + result. + + Attributes: + rating (google.ai.generativelanguage_v1alpha.types.SafetyRating): + Safety rating evaluated from content. + setting (google.ai.generativelanguage_v1alpha.types.SafetySetting): + Safety settings applied to the request. + """ + + rating: "SafetyRating" = proto.Field( + proto.MESSAGE, + number=1, + message="SafetyRating", + ) + setting: "SafetySetting" = proto.Field( + proto.MESSAGE, + number=2, + message="SafetySetting", + ) + + +class SafetyRating(proto.Message): + r"""Safety rating for a piece of content. + + The safety rating contains the category of harm and the harm + probability level in that category for a piece of content. + Content is classified for safety across a number of harm + categories and the probability of the harm classification is + included here. + + Attributes: + category (google.ai.generativelanguage_v1alpha.types.HarmCategory): + Required. The category for this rating. + probability (google.ai.generativelanguage_v1alpha.types.SafetyRating.HarmProbability): + Required. The probability of harm for this + content. + blocked (bool): + Was this content blocked because of this + rating? + """ + + class HarmProbability(proto.Enum): + r"""The probability that a piece of content is harmful. + + The classification system gives the probability of the content + being unsafe. This does not indicate the severity of harm for a + piece of content. + + Values: + HARM_PROBABILITY_UNSPECIFIED (0): + Probability is unspecified. + NEGLIGIBLE (1): + Content has a negligible chance of being + unsafe. + LOW (2): + Content has a low chance of being unsafe. + MEDIUM (3): + Content has a medium chance of being unsafe. + HIGH (4): + Content has a high chance of being unsafe. + """ + HARM_PROBABILITY_UNSPECIFIED = 0 + NEGLIGIBLE = 1 + LOW = 2 + MEDIUM = 3 + HIGH = 4 + + category: "HarmCategory" = proto.Field( + proto.ENUM, + number=3, + enum="HarmCategory", + ) + probability: HarmProbability = proto.Field( + proto.ENUM, + number=4, + enum=HarmProbability, + ) + blocked: bool = proto.Field( + proto.BOOL, + number=5, + ) + + +class SafetySetting(proto.Message): + r"""Safety setting, affecting the safety-blocking behavior. + + Passing a safety setting for a category changes the allowed + probability that content is blocked. + + Attributes: + category (google.ai.generativelanguage_v1alpha.types.HarmCategory): + Required. The category for this setting. + threshold (google.ai.generativelanguage_v1alpha.types.SafetySetting.HarmBlockThreshold): + Required. Controls the probability threshold + at which harm is blocked. + """ + + class HarmBlockThreshold(proto.Enum): + r"""Block at and beyond a specified harm probability. + + Values: + HARM_BLOCK_THRESHOLD_UNSPECIFIED (0): + Threshold is unspecified. + BLOCK_LOW_AND_ABOVE (1): + Content with NEGLIGIBLE will be allowed. + BLOCK_MEDIUM_AND_ABOVE (2): + Content with NEGLIGIBLE and LOW will be + allowed. + BLOCK_ONLY_HIGH (3): + Content with NEGLIGIBLE, LOW, and MEDIUM will + be allowed. + BLOCK_NONE (4): + All content will be allowed. + OFF (5): + Turn off the safety filter. + """ + HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0 + BLOCK_LOW_AND_ABOVE = 1 + BLOCK_MEDIUM_AND_ABOVE = 2 + BLOCK_ONLY_HIGH = 3 + BLOCK_NONE = 4 + OFF = 5 + + category: "HarmCategory" = proto.Field( + proto.ENUM, + number=3, + enum="HarmCategory", + ) + threshold: HarmBlockThreshold = proto.Field( + proto.ENUM, + number=4, + enum=HarmBlockThreshold, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/text_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/text_service.py new file mode 100644 index 000000000000..9daa38137d05 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/text_service.py @@ -0,0 +1,441 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +import proto # type: ignore + +from google.ai.generativelanguage_v1alpha.types import citation, safety + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "GenerateTextRequest", + "GenerateTextResponse", + "TextPrompt", + "TextCompletion", + "EmbedTextRequest", + "EmbedTextResponse", + "BatchEmbedTextRequest", + "BatchEmbedTextResponse", + "Embedding", + "CountTextTokensRequest", + "CountTextTokensResponse", + }, +) + + +class GenerateTextRequest(proto.Message): + r"""Request to generate a text completion response from the + model. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + model (str): + Required. The name of the ``Model`` or ``TunedModel`` to use + for generating the completion. Examples: + models/text-bison-001 tunedModels/sentence-translator-u3b7m + prompt (google.ai.generativelanguage_v1alpha.types.TextPrompt): + Required. The free-form input text given to + the model as a prompt. + Given a prompt, the model will generate a + TextCompletion response it predicts as the + completion of the input text. + temperature (float): + Optional. Controls the randomness of the output. Note: The + default value varies by model, see the ``Model.temperature`` + attribute of the ``Model`` returned the ``getModel`` + function. + + Values can range from [0.0,1.0], inclusive. A value closer + to 1.0 will produce responses that are more varied and + creative, while a value closer to 0.0 will typically result + in more straightforward responses from the model. + + This field is a member of `oneof`_ ``_temperature``. + candidate_count (int): + Optional. Number of generated responses to return. + + This value must be between [1, 8], inclusive. If unset, this + will default to 1. + + This field is a member of `oneof`_ ``_candidate_count``. + max_output_tokens (int): + Optional. The maximum number of tokens to include in a + candidate. + + If unset, this will default to output_token_limit specified + in the ``Model`` specification. + + This field is a member of `oneof`_ ``_max_output_tokens``. + top_p (float): + Optional. The maximum cumulative probability of tokens to + consider when sampling. + + The model uses combined Top-k and nucleus sampling. + + Tokens are sorted based on their assigned probabilities so + that only the most likely tokens are considered. Top-k + sampling directly limits the maximum number of tokens to + consider, while Nucleus sampling limits number of tokens + based on the cumulative probability. + + Note: The default value varies by model, see the + ``Model.top_p`` attribute of the ``Model`` returned the + ``getModel`` function. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + Optional. The maximum number of tokens to consider when + sampling. + + The model uses combined Top-k and nucleus sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. Defaults to 40. + + Note: The default value varies by model, see the + ``Model.top_k`` attribute of the ``Model`` returned the + ``getModel`` function. + + This field is a member of `oneof`_ ``_top_k``. + safety_settings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]): + Optional. A list of unique ``SafetySetting`` instances for + blocking unsafe content. + + that will be enforced on the ``GenerateTextRequest.prompt`` + and ``GenerateTextResponse.candidates``. There should not be + more than one setting for each ``SafetyCategory`` type. The + API will block any prompts and responses that fail to meet + the thresholds set by these settings. This list overrides + the default settings for each ``SafetyCategory`` specified + in the safety_settings. If there is no ``SafetySetting`` for + a given ``SafetyCategory`` provided in the list, the API + will use the default safety setting for that category. Harm + categories HARM_CATEGORY_DEROGATORY, HARM_CATEGORY_TOXICITY, + HARM_CATEGORY_VIOLENCE, HARM_CATEGORY_SEXUAL, + HARM_CATEGORY_MEDICAL, HARM_CATEGORY_DANGEROUS are supported + in text service. + stop_sequences (MutableSequence[str]): + The set of character sequences (up to 5) that + will stop output generation. If specified, the + API will stop at the first appearance of a stop + sequence. The stop sequence will not be included + as part of the response. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: "TextPrompt" = proto.Field( + proto.MESSAGE, + number=2, + message="TextPrompt", + ) + temperature: float = proto.Field( + proto.FLOAT, + number=3, + optional=True, + ) + candidate_count: int = proto.Field( + proto.INT32, + number=4, + optional=True, + ) + max_output_tokens: int = proto.Field( + proto.INT32, + number=5, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=6, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=7, + optional=True, + ) + safety_settings: MutableSequence[safety.SafetySetting] = proto.RepeatedField( + proto.MESSAGE, + number=8, + message=safety.SafetySetting, + ) + stop_sequences: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=9, + ) + + +class GenerateTextResponse(proto.Message): + r"""The response from the model, including candidate completions. + + Attributes: + candidates (MutableSequence[google.ai.generativelanguage_v1alpha.types.TextCompletion]): + Candidate responses from the model. + filters (MutableSequence[google.ai.generativelanguage_v1alpha.types.ContentFilter]): + A set of content filtering metadata for the prompt and + response text. + + This indicates which ``SafetyCategory``\ (s) blocked a + candidate from this response, the lowest ``HarmProbability`` + that triggered a block, and the HarmThreshold setting for + that category. This indicates the smallest change to the + ``SafetySettings`` that would be necessary to unblock at + least 1 response. + + The blocking is configured by the ``SafetySettings`` in the + request (or the default ``SafetySettings`` of the API). + safety_feedback (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetyFeedback]): + Returns any safety feedback related to + content filtering. + """ + + candidates: MutableSequence["TextCompletion"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="TextCompletion", + ) + filters: MutableSequence[safety.ContentFilter] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message=safety.ContentFilter, + ) + safety_feedback: MutableSequence[safety.SafetyFeedback] = proto.RepeatedField( + proto.MESSAGE, + number=4, + message=safety.SafetyFeedback, + ) + + +class TextPrompt(proto.Message): + r"""Text given to the model as a prompt. + + The Model will use this TextPrompt to Generate a text + completion. + + Attributes: + text (str): + Required. The prompt text. + """ + + text: str = proto.Field( + proto.STRING, + number=1, + ) + + +class TextCompletion(proto.Message): + r"""Output text returned from a model. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + output (str): + Output only. The generated text returned from + the model. + safety_ratings (MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetyRating]): + Ratings for the safety of a response. + + There is at most one rating per category. + citation_metadata (google.ai.generativelanguage_v1alpha.types.CitationMetadata): + Output only. Citation information for model-generated + ``output`` in this ``TextCompletion``. + + This field may be populated with attribution information for + any text included in the ``output``. + + This field is a member of `oneof`_ ``_citation_metadata``. + """ + + output: str = proto.Field( + proto.STRING, + number=1, + ) + safety_ratings: MutableSequence[safety.SafetyRating] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message=safety.SafetyRating, + ) + citation_metadata: citation.CitationMetadata = proto.Field( + proto.MESSAGE, + number=3, + optional=True, + message=citation.CitationMetadata, + ) + + +class EmbedTextRequest(proto.Message): + r"""Request to get a text embedding from the model. + + Attributes: + model (str): + Required. The model name to use with the + format model=models/{model}. + text (str): + Optional. The free-form input text that the + model will turn into an embedding. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + text: str = proto.Field( + proto.STRING, + number=2, + ) + + +class EmbedTextResponse(proto.Message): + r"""The response to a EmbedTextRequest. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + embedding (google.ai.generativelanguage_v1alpha.types.Embedding): + Output only. The embedding generated from the + input text. + + This field is a member of `oneof`_ ``_embedding``. + """ + + embedding: "Embedding" = proto.Field( + proto.MESSAGE, + number=1, + optional=True, + message="Embedding", + ) + + +class BatchEmbedTextRequest(proto.Message): + r"""Batch request to get a text embedding from the model. + + Attributes: + model (str): + Required. The name of the ``Model`` to use for generating + the embedding. Examples: models/embedding-gecko-001 + texts (MutableSequence[str]): + Optional. The free-form input texts that the + model will turn into an embedding. The current + limit is 100 texts, over which an error will be + thrown. + requests (MutableSequence[google.ai.generativelanguage_v1alpha.types.EmbedTextRequest]): + Optional. Embed requests for the batch. Only one of + ``texts`` or ``requests`` can be set. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + texts: MutableSequence[str] = proto.RepeatedField( + proto.STRING, + number=2, + ) + requests: MutableSequence["EmbedTextRequest"] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="EmbedTextRequest", + ) + + +class BatchEmbedTextResponse(proto.Message): + r"""The response to a EmbedTextRequest. + + Attributes: + embeddings (MutableSequence[google.ai.generativelanguage_v1alpha.types.Embedding]): + Output only. The embeddings generated from + the input text. + """ + + embeddings: MutableSequence["Embedding"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="Embedding", + ) + + +class Embedding(proto.Message): + r"""A list of floats representing the embedding. + + Attributes: + value (MutableSequence[float]): + The embedding values. + """ + + value: MutableSequence[float] = proto.RepeatedField( + proto.FLOAT, + number=1, + ) + + +class CountTextTokensRequest(proto.Message): + r"""Counts the number of tokens in the ``prompt`` sent to a model. + + Models may tokenize text differently, so each model may return a + different ``token_count``. + + Attributes: + model (str): + Required. The model's resource name. This serves as an ID + for the Model to use. + + This name should match a model name returned by the + ``ListModels`` method. + + Format: ``models/{model}`` + prompt (google.ai.generativelanguage_v1alpha.types.TextPrompt): + Required. The free-form input text given to + the model as a prompt. + """ + + model: str = proto.Field( + proto.STRING, + number=1, + ) + prompt: "TextPrompt" = proto.Field( + proto.MESSAGE, + number=2, + message="TextPrompt", + ) + + +class CountTextTokensResponse(proto.Message): + r"""A response from ``CountTextTokens``. + + It returns the model's ``token_count`` for the ``prompt``. + + Attributes: + token_count (int): + The number of tokens that the ``model`` tokenizes the + ``prompt`` into. + + Always non-negative. + """ + + token_count: int = proto.Field( + proto.INT32, + number=1, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/tuned_model.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/tuned_model.py new file mode 100644 index 000000000000..d714c1f8c4d5 --- /dev/null +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1alpha/types/tuned_model.py @@ -0,0 +1,542 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from typing import MutableMapping, MutableSequence + +from google.protobuf import timestamp_pb2 # type: ignore +import proto # type: ignore + +__protobuf__ = proto.module( + package="google.ai.generativelanguage.v1alpha", + manifest={ + "TunedModel", + "TunedModelSource", + "TuningTask", + "Hyperparameters", + "Dataset", + "TuningExamples", + "TuningPart", + "TuningContent", + "TuningMultiturnExample", + "TuningExample", + "TuningSnapshot", + }, +) + + +class TunedModel(proto.Message): + r"""A fine-tuned model created using + ModelService.CreateTunedModel. + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + tuned_model_source (google.ai.generativelanguage_v1alpha.types.TunedModelSource): + Optional. TunedModel to use as the starting + point for training the new model. + + This field is a member of `oneof`_ ``source_model``. + base_model (str): + Immutable. The name of the ``Model`` to tune. Example: + ``models/gemini-1.5-flash-001`` + + This field is a member of `oneof`_ ``source_model``. + name (str): + Output only. The tuned model name. A unique name will be + generated on create. Example: ``tunedModels/az2mb0bpw6i`` If + display_name is set on create, the id portion of the name + will be set by concatenating the words of the display_name + with hyphens and adding a random portion for uniqueness. + + Example: + + - display_name = ``Sentence Translator`` + - name = ``tunedModels/sentence-translator-u3b7m`` + display_name (str): + Optional. The name to display for this model + in user interfaces. The display name must be up + to 40 characters including spaces. + description (str): + Optional. A short description of this model. + temperature (float): + Optional. Controls the randomness of the output. + + Values can range over ``[0.0,1.0]``, inclusive. A value + closer to ``1.0`` will produce responses that are more + varied, while a value closer to ``0.0`` will typically + result in less surprising responses from the model. + + This value specifies default to be the one used by the base + model while creating the model. + + This field is a member of `oneof`_ ``_temperature``. + top_p (float): + Optional. For Nucleus sampling. + + Nucleus sampling considers the smallest set of tokens whose + probability sum is at least ``top_p``. + + This value specifies default to be the one used by the base + model while creating the model. + + This field is a member of `oneof`_ ``_top_p``. + top_k (int): + Optional. For Top-k sampling. + + Top-k sampling considers the set of ``top_k`` most probable + tokens. This value specifies default to be used by the + backend while making the call to the model. + + This value specifies default to be the one used by the base + model while creating the model. + + This field is a member of `oneof`_ ``_top_k``. + state (google.ai.generativelanguage_v1alpha.types.TunedModel.State): + Output only. The state of the tuned model. + create_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when this model + was created. + update_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when this model + was updated. + tuning_task (google.ai.generativelanguage_v1alpha.types.TuningTask): + Required. The tuning task that creates the + tuned model. + reader_project_numbers (MutableSequence[int]): + Optional. List of project numbers that have + read access to the tuned model. + """ + + class State(proto.Enum): + r"""The state of the tuned model. + + Values: + STATE_UNSPECIFIED (0): + The default value. This value is unused. + CREATING (1): + The model is being created. + ACTIVE (2): + The model is ready to be used. + FAILED (3): + The model failed to be created. + """ + STATE_UNSPECIFIED = 0 + CREATING = 1 + ACTIVE = 2 + FAILED = 3 + + tuned_model_source: "TunedModelSource" = proto.Field( + proto.MESSAGE, + number=3, + oneof="source_model", + message="TunedModelSource", + ) + base_model: str = proto.Field( + proto.STRING, + number=4, + oneof="source_model", + ) + name: str = proto.Field( + proto.STRING, + number=1, + ) + display_name: str = proto.Field( + proto.STRING, + number=5, + ) + description: str = proto.Field( + proto.STRING, + number=6, + ) + temperature: float = proto.Field( + proto.FLOAT, + number=11, + optional=True, + ) + top_p: float = proto.Field( + proto.FLOAT, + number=12, + optional=True, + ) + top_k: int = proto.Field( + proto.INT32, + number=13, + optional=True, + ) + state: State = proto.Field( + proto.ENUM, + number=7, + enum=State, + ) + create_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=8, + message=timestamp_pb2.Timestamp, + ) + update_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=9, + message=timestamp_pb2.Timestamp, + ) + tuning_task: "TuningTask" = proto.Field( + proto.MESSAGE, + number=10, + message="TuningTask", + ) + reader_project_numbers: MutableSequence[int] = proto.RepeatedField( + proto.INT64, + number=14, + ) + + +class TunedModelSource(proto.Message): + r"""Tuned model as a source for training a new model. + + Attributes: + tuned_model (str): + Immutable. The name of the ``TunedModel`` to use as the + starting point for training the new model. Example: + ``tunedModels/my-tuned-model`` + base_model (str): + Output only. The name of the base ``Model`` this + ``TunedModel`` was tuned from. Example: + ``models/gemini-1.5-flash-001`` + """ + + tuned_model: str = proto.Field( + proto.STRING, + number=1, + ) + base_model: str = proto.Field( + proto.STRING, + number=2, + ) + + +class TuningTask(proto.Message): + r"""Tuning tasks that create tuned models. + + Attributes: + start_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when tuning this + model started. + complete_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when tuning this + model completed. + snapshots (MutableSequence[google.ai.generativelanguage_v1alpha.types.TuningSnapshot]): + Output only. Metrics collected during tuning. + training_data (google.ai.generativelanguage_v1alpha.types.Dataset): + Required. Input only. Immutable. The model + training data. + hyperparameters (google.ai.generativelanguage_v1alpha.types.Hyperparameters): + Immutable. Hyperparameters controlling the + tuning process. If not provided, default values + will be used. + """ + + start_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=1, + message=timestamp_pb2.Timestamp, + ) + complete_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=2, + message=timestamp_pb2.Timestamp, + ) + snapshots: MutableSequence["TuningSnapshot"] = proto.RepeatedField( + proto.MESSAGE, + number=3, + message="TuningSnapshot", + ) + training_data: "Dataset" = proto.Field( + proto.MESSAGE, + number=4, + message="Dataset", + ) + hyperparameters: "Hyperparameters" = proto.Field( + proto.MESSAGE, + number=5, + message="Hyperparameters", + ) + + +class Hyperparameters(proto.Message): + r"""Hyperparameters controlling the tuning process. Read more at + https://ai.google.dev/docs/model_tuning_guidance + + This message has `oneof`_ fields (mutually exclusive fields). + For each oneof, at most one member field can be set at the same time. + Setting any member of the oneof automatically clears all other + members. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + learning_rate (float): + Optional. Immutable. The learning rate + hyperparameter for tuning. If not set, a default + of 0.001 or 0.0002 will be calculated based on + the number of training examples. + + This field is a member of `oneof`_ ``learning_rate_option``. + learning_rate_multiplier (float): + Optional. Immutable. The learning rate multiplier is used to + calculate a final learning_rate based on the default + (recommended) value. Actual learning rate := + learning_rate_multiplier \* default learning rate Default + learning rate is dependent on base model and dataset size. + If not set, a default of 1.0 will be used. + + This field is a member of `oneof`_ ``learning_rate_option``. + epoch_count (int): + Immutable. The number of training epochs. An + epoch is one pass through the training data. If + not set, a default of 5 will be used. + + This field is a member of `oneof`_ ``_epoch_count``. + batch_size (int): + Immutable. The batch size hyperparameter for + tuning. If not set, a default of 4 or 16 will be + used based on the number of training examples. + + This field is a member of `oneof`_ ``_batch_size``. + """ + + learning_rate: float = proto.Field( + proto.FLOAT, + number=16, + oneof="learning_rate_option", + ) + learning_rate_multiplier: float = proto.Field( + proto.FLOAT, + number=17, + oneof="learning_rate_option", + ) + epoch_count: int = proto.Field( + proto.INT32, + number=14, + optional=True, + ) + batch_size: int = proto.Field( + proto.INT32, + number=15, + optional=True, + ) + + +class Dataset(proto.Message): + r"""Dataset for training or validation. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + examples (google.ai.generativelanguage_v1alpha.types.TuningExamples): + Optional. Inline examples with simple + input/output text. + + This field is a member of `oneof`_ ``dataset``. + """ + + examples: "TuningExamples" = proto.Field( + proto.MESSAGE, + number=1, + oneof="dataset", + message="TuningExamples", + ) + + +class TuningExamples(proto.Message): + r"""A set of tuning examples. Can be training or validation data. + + Attributes: + examples (MutableSequence[google.ai.generativelanguage_v1alpha.types.TuningExample]): + The examples. Example input can be for text + or discuss, but all examples in a set must be of + the same type. + multiturn_examples (MutableSequence[google.ai.generativelanguage_v1alpha.types.TuningMultiturnExample]): + Content examples. For multiturn + conversations. + """ + + examples: MutableSequence["TuningExample"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="TuningExample", + ) + multiturn_examples: MutableSequence["TuningMultiturnExample"] = proto.RepeatedField( + proto.MESSAGE, + number=2, + message="TuningMultiturnExample", + ) + + +class TuningPart(proto.Message): + r"""A datatype containing data that is part of a multi-part + ``TuningContent`` message. + + This is a subset of the Part used for model inference, with limited + type support. + + A ``Part`` consists of data which has an associated datatype. A + ``Part`` can only contain one of the accepted types in + ``Part.data``. + + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + text (str): + Inline text. + + This field is a member of `oneof`_ ``data``. + """ + + text: str = proto.Field( + proto.STRING, + number=2, + oneof="data", + ) + + +class TuningContent(proto.Message): + r"""The structured datatype containing multi-part content of an example + message. + + This is a subset of the Content proto used during model inference + with limited type support. A ``Content`` includes a ``role`` field + designating the producer of the ``Content`` and a ``parts`` field + containing multi-part data that contains the content of the message + turn. + + Attributes: + parts (MutableSequence[google.ai.generativelanguage_v1alpha.types.TuningPart]): + Ordered ``Parts`` that constitute a single message. Parts + may have different MIME types. + role (str): + Optional. The producer of the content. Must + be either 'user' or 'model'. + Useful to set for multi-turn conversations, + otherwise can be left blank or unset. + """ + + parts: MutableSequence["TuningPart"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="TuningPart", + ) + role: str = proto.Field( + proto.STRING, + number=2, + ) + + +class TuningMultiturnExample(proto.Message): + r"""A tuning example with multiturn input. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + system_instruction (google.ai.generativelanguage_v1alpha.types.TuningContent): + Optional. Developer set system instructions. + Currently, text only. + + This field is a member of `oneof`_ ``_system_instruction``. + contents (MutableSequence[google.ai.generativelanguage_v1alpha.types.TuningContent]): + Each Content represents a turn in the + conversation. + """ + + system_instruction: "TuningContent" = proto.Field( + proto.MESSAGE, + number=8, + optional=True, + message="TuningContent", + ) + contents: MutableSequence["TuningContent"] = proto.RepeatedField( + proto.MESSAGE, + number=1, + message="TuningContent", + ) + + +class TuningExample(proto.Message): + r"""A single example for tuning. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + text_input (str): + Optional. Text model input. + + This field is a member of `oneof`_ ``model_input``. + output (str): + Required. The expected model output. + """ + + text_input: str = proto.Field( + proto.STRING, + number=1, + oneof="model_input", + ) + output: str = proto.Field( + proto.STRING, + number=3, + ) + + +class TuningSnapshot(proto.Message): + r"""Record for a single tuning step. + + Attributes: + step (int): + Output only. The tuning step. + epoch (int): + Output only. The epoch this step was part of. + mean_loss (float): + Output only. The mean loss of the training + examples for this step. + compute_time (google.protobuf.timestamp_pb2.Timestamp): + Output only. The timestamp when this metric + was computed. + """ + + step: int = proto.Field( + proto.INT32, + number=1, + ) + epoch: int = proto.Field( + proto.INT32, + number=2, + ) + mean_loss: float = proto.Field( + proto.FLOAT, + number=3, + ) + compute_time: timestamp_pb2.Timestamp = proto.Field( + proto.MESSAGE, + number=4, + message=timestamp_pb2.Timestamp, + ) + + +__all__ = tuple(sorted(__protobuf__.manifest)) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/__init__.py index 73da8c53fefc..9540c3ab3502 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/__init__.py @@ -108,11 +108,14 @@ GroundingMetadata, GroundingSupport, LogprobsResult, + PrebuiltVoiceConfig, RetrievalMetadata, SearchEntryPoint, Segment, SemanticRetrieverConfig, + SpeechConfig, TaskType, + VoiceConfig, ) from .types.model import Model from .types.model_service import ( @@ -338,6 +341,7 @@ "Part", "Permission", "PermissionServiceClient", + "PrebuiltVoiceConfig", "PredictRequest", "PredictResponse", "PredictionServiceClient", @@ -355,6 +359,7 @@ "SearchEntryPoint", "Segment", "SemanticRetrieverConfig", + "SpeechConfig", "StringList", "TaskType", "TextCompletion", @@ -378,4 +383,5 @@ "UpdatePermissionRequest", "UpdateTunedModelRequest", "VideoMetadata", + "VoiceConfig", ) diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/gapic_version.py index 0b6dbde2b051..558c8aab67c5 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/gapic_version.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.6.14" # {x-release-please-version} +__version__ = "0.0.0" # {x-release-please-version} diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/async_client.py index cc4d74a856c1..a466720b829e 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/async_client.py @@ -513,6 +513,8 @@ async def sample_get_file(): Returns: google.ai.generativelanguage_v1beta.types.File: A file uploaded to the API. + Next ID: 15 + """ # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/client.py index 3de779357670..f4e97c3210b9 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/client.py @@ -895,6 +895,8 @@ def sample_get_file(): Returns: google.ai.generativelanguage_v1beta.types.File: A file uploaded to the API. + Next ID: 15 + """ # Create or coerce a protobuf request object. # - Quick check: If we got a request object, we should *not* have diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/transports/rest.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/transports/rest.py index ab839007a690..82c4b919bee6 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/transports/rest.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/file_service/transports/rest.py @@ -630,6 +630,8 @@ def __call__( Returns: ~.file.File: A file uploaded to the API. + Next ID: 15 + """ http_options = ( diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/async_client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/async_client.py index 535e5cff0d45..efc2b8b8b21d 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/async_client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/async_client.py @@ -357,7 +357,7 @@ async def sample_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this @@ -671,7 +671,7 @@ async def sample_stream_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/client.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/client.py index 02c6b988a806..b4588a5cb920 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/client.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/services/generative_service/client.py @@ -755,7 +755,7 @@ def sample_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this @@ -1063,7 +1063,7 @@ def sample_stream_generate_content(): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. This corresponds to the ``model`` field on the ``request`` instance; if ``request`` is provided, this diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/__init__.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/__init__.py index 9dd7a564142d..19e8bf5d18bb 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/__init__.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/__init__.py @@ -82,11 +82,14 @@ GroundingMetadata, GroundingSupport, LogprobsResult, + PrebuiltVoiceConfig, RetrievalMetadata, SearchEntryPoint, Segment, SemanticRetrieverConfig, + SpeechConfig, TaskType, + VoiceConfig, ) from .model import Model from .model_service import ( @@ -247,10 +250,13 @@ "GroundingMetadata", "GroundingSupport", "LogprobsResult", + "PrebuiltVoiceConfig", "RetrievalMetadata", "SearchEntryPoint", "Segment", "SemanticRetrieverConfig", + "SpeechConfig", + "VoiceConfig", "TaskType", "Model", "CreateTunedModelMetadata", diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/content.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/content.py index 6b5d37cd15ce..04712f6f88df 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/content.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/content.py @@ -371,8 +371,18 @@ class Tool(proto.Message): code_execution (google.ai.generativelanguage_v1beta.types.CodeExecution): Optional. Enables the model to execute code as part of generation. + google_search (google.ai.generativelanguage_v1beta.types.Tool.GoogleSearch): + Optional. GoogleSearch tool type. + Tool to support Google Search in Model. Powered + by Google. """ + class GoogleSearch(proto.Message): + r"""GoogleSearch tool type. + Tool to support Google Search in Model. Powered by Google. + + """ + function_declarations: MutableSequence["FunctionDeclaration"] = proto.RepeatedField( proto.MESSAGE, number=1, @@ -388,6 +398,11 @@ class Tool(proto.Message): number=3, message="CodeExecution", ) + google_search: GoogleSearch = proto.Field( + proto.MESSAGE, + number=4, + message=GoogleSearch, + ) class GoogleSearchRetrieval(proto.Message): @@ -560,6 +575,14 @@ class FunctionDeclaration(proto.Message): parameter. This field is a member of `oneof`_ ``_parameters``. + response (google.ai.generativelanguage_v1beta.types.Schema): + Optional. Describes the output from this + function in JSON Schema format. Reflects the + Open API 3.03 Response Object. The Schema + defines the type used for the response value of + the function. + + This field is a member of `oneof`_ ``_response``. """ name: str = proto.Field( @@ -576,6 +599,12 @@ class FunctionDeclaration(proto.Message): optional=True, message="Schema", ) + response: "Schema" = proto.Field( + proto.MESSAGE, + number=4, + optional=True, + message="Schema", + ) class FunctionCall(proto.Message): @@ -587,6 +616,10 @@ class FunctionCall(proto.Message): .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields Attributes: + id (str): + Optional. The unique id of the function call. If populated, + the client to execute the ``function_call`` and return the + response with the matching ``id``. name (str): Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores @@ -598,6 +631,10 @@ class FunctionCall(proto.Message): This field is a member of `oneof`_ ``_args``. """ + id: str = proto.Field( + proto.STRING, + number=3, + ) name: str = proto.Field( proto.STRING, number=1, @@ -618,6 +655,10 @@ class FunctionResponse(proto.Message): made based on model prediction. Attributes: + id (str): + Optional. The id of the function call this response is for. + Populated by the client to match the corresponding function + call ``id``. name (str): Required. The name of the function to call. Must be a-z, A-Z, 0-9, or contain underscores @@ -627,6 +668,10 @@ class FunctionResponse(proto.Message): object format. """ + id: str = proto.Field( + proto.STRING, + number=3, + ) name: str = proto.Field( proto.STRING, number=1, diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/file.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/file.py index 387d00aafbf7..b5621298c671 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/file.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/file.py @@ -33,6 +33,8 @@ class File(proto.Message): r"""A file uploaded to the API. + Next ID: 15 + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/generative_service.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/generative_service.py index 27a344b99ec5..c66908d84188 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/generative_service.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta/types/generative_service.py @@ -28,6 +28,9 @@ manifest={ "TaskType", "GenerateContentRequest", + "PrebuiltVoiceConfig", + "VoiceConfig", + "SpeechConfig", "GenerationConfig", "SemanticRetrieverConfig", "GenerateContentResponse", @@ -103,7 +106,7 @@ class GenerateContentRequest(proto.Message): Required. The name of the ``Model`` to use for generating the completion. - Format: ``name=models/{model}``. + Format: ``models/{model}``. system_instruction (google.ai.generativelanguage_v1beta.types.Content): Optional. Developer set `system instruction(s) `__. @@ -153,8 +156,8 @@ class GenerateContentRequest(proto.Message): will use the default safety setting for that category. Harm categories HARM_CATEGORY_HATE_SPEECH, HARM_CATEGORY_SEXUALLY_EXPLICIT, - HARM_CATEGORY_DANGEROUS_CONTENT, HARM_CATEGORY_HARASSMENT - are supported. Refer to the + HARM_CATEGORY_DANGEROUS_CONTENT, HARM_CATEGORY_HARASSMENT, + HARM_CATEGORY_CIVIC_INTEGRITY are supported. Refer to the `guide `__ for detailed information on available safety settings. Also refer to the `Safety @@ -218,6 +221,61 @@ class GenerateContentRequest(proto.Message): ) +class PrebuiltVoiceConfig(proto.Message): + r"""The configuration for the prebuilt speaker to use. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + voice_name (str): + The name of the preset voice to use. + + This field is a member of `oneof`_ ``_voice_name``. + """ + + voice_name: str = proto.Field( + proto.STRING, + number=1, + optional=True, + ) + + +class VoiceConfig(proto.Message): + r"""The configuration for the voice to use. + + .. _oneof: https://proto-plus-python.readthedocs.io/en/stable/fields.html#oneofs-mutually-exclusive-fields + + Attributes: + prebuilt_voice_config (google.ai.generativelanguage_v1beta.types.PrebuiltVoiceConfig): + The configuration for the prebuilt voice to + use. + + This field is a member of `oneof`_ ``voice_config``. + """ + + prebuilt_voice_config: "PrebuiltVoiceConfig" = proto.Field( + proto.MESSAGE, + number=1, + oneof="voice_config", + message="PrebuiltVoiceConfig", + ) + + +class SpeechConfig(proto.Message): + r"""The speech generation config. + + Attributes: + voice_config (google.ai.generativelanguage_v1beta.types.VoiceConfig): + The configuration for the speaker to use. + """ + + voice_config: "VoiceConfig" = proto.Field( + proto.MESSAGE, + number=1, + message="VoiceConfig", + ) + + class GenerationConfig(proto.Message): r"""Configuration options for model generation and outputs. Not all parameters are configurable for every model. @@ -363,8 +421,49 @@ class GenerationConfig(proto.Message): [Candidate.logprobs_result][google.ai.generativelanguage.v1beta.Candidate.logprobs_result]. This field is a member of `oneof`_ ``_logprobs``. + enable_enhanced_civic_answers (bool): + Optional. Enables enhanced civic answers. It + may not be available for all models. + + This field is a member of `oneof`_ ``_enable_enhanced_civic_answers``. + response_modalities (MutableSequence[google.ai.generativelanguage_v1beta.types.GenerationConfig.Modality]): + Optional. The requested modalities of the + response. Represents the set of modalities that + the model can return, and should be expected in + the response. This is an exact match to the + modalities of the response. + + A model may have multiple combinations of + supported modalities. If the requested + modalities do not match any of the supported + combinations, an error will be returned. + + An empty list is equivalent to requesting only + text. + speech_config (google.ai.generativelanguage_v1beta.types.SpeechConfig): + Optional. The speech generation config. + + This field is a member of `oneof`_ ``_speech_config``. """ + class Modality(proto.Enum): + r"""Supported modalities of the response. + + Values: + MODALITY_UNSPECIFIED (0): + Default value. + TEXT (1): + Indicates the model should return text. + IMAGE (2): + Indicates the model should return images. + AUDIO (3): + Indicates the model should return audio. + """ + MODALITY_UNSPECIFIED = 0 + TEXT = 1 + IMAGE = 2 + AUDIO = 3 + candidate_count: int = proto.Field( proto.INT32, number=1, @@ -423,6 +522,22 @@ class GenerationConfig(proto.Message): number=18, optional=True, ) + enable_enhanced_civic_answers: bool = proto.Field( + proto.BOOL, + number=19, + optional=True, + ) + response_modalities: MutableSequence[Modality] = proto.RepeatedField( + proto.ENUM, + number=20, + enum=Modality, + ) + speech_config: "SpeechConfig" = proto.Field( + proto.MESSAGE, + number=21, + optional=True, + message="SpeechConfig", + ) class SemanticRetrieverConfig(proto.Message): @@ -537,12 +652,16 @@ class BlockReason(proto.Enum): included from the terminology blocklist. PROHIBITED_CONTENT (4): Prompt was blocked due to prohibited content. + IMAGE_SAFETY (5): + Candidates blocked due to unsafe image + generation content. """ BLOCK_REASON_UNSPECIFIED = 0 SAFETY = 1 OTHER = 2 BLOCKLIST = 3 PROHIBITED_CONTENT = 4 + IMAGE_SAFETY = 5 block_reason: "GenerateContentResponse.PromptFeedback.BlockReason" = ( proto.Field( @@ -700,6 +819,9 @@ class FinishReason(proto.Enum): MALFORMED_FUNCTION_CALL (10): The function call generated by the model is invalid. + IMAGE_SAFETY (11): + Token generation stopped because generated + images contain safety violations. """ FINISH_REASON_UNSPECIFIED = 0 STOP = 1 @@ -712,6 +834,7 @@ class FinishReason(proto.Enum): PROHIBITED_CONTENT = 8 SPII = 9 MALFORMED_FUNCTION_CALL = 10 + IMAGE_SAFETY = 11 index: int = proto.Field( proto.INT32, diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/gapic_version.py index 0b6dbde2b051..558c8aab67c5 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/gapic_version.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta2/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.6.14" # {x-release-please-version} +__version__ = "0.0.0" # {x-release-please-version} diff --git a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_version.py b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_version.py index 0b6dbde2b051..558c8aab67c5 100644 --- a/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_version.py +++ b/packages/google-ai-generativelanguage/google/ai/generativelanguage_v1beta3/gapic_version.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -__version__ = "0.6.14" # {x-release-please-version} +__version__ = "0.0.0" # {x-release-please-version} diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_async.py new file mode 100644 index 000000000000..99169d1f155c --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_CreateCachedContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCachedContentRequest( + ) + + # Make the request + response = await client.create_cached_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_CreateCachedContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_sync.py new file mode 100644 index 000000000000..a1b3a5386fce --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_create_cached_content_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_CreateCachedContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCachedContentRequest( + ) + + # Make the request + response = client.create_cached_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_CreateCachedContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_async.py new file mode 100644 index 000000000000..af9ef6361c2f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_DeleteCachedContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCachedContentRequest( + name="name_value", + ) + + # Make the request + await client.delete_cached_content(request=request) + + +# [END generativelanguage_v1alpha_generated_CacheService_DeleteCachedContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_sync.py new file mode 100644 index 000000000000..9735af0fb8ea --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_delete_cached_content_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_DeleteCachedContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCachedContentRequest( + name="name_value", + ) + + # Make the request + client.delete_cached_content(request=request) + + +# [END generativelanguage_v1alpha_generated_CacheService_DeleteCachedContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_async.py new file mode 100644 index 000000000000..2b328840fe16 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_GetCachedContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCachedContentRequest( + name="name_value", + ) + + # Make the request + response = await client.get_cached_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_GetCachedContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_sync.py new file mode 100644 index 000000000000..2ed52f862be9 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_get_cached_content_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_GetCachedContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCachedContentRequest( + name="name_value", + ) + + # Make the request + response = client.get_cached_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_GetCachedContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_async.py new file mode 100644 index 000000000000..1641340b03eb --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListCachedContents +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_ListCachedContents_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_cached_contents(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCachedContentsRequest( + ) + + # Make the request + page_result = client.list_cached_contents(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_ListCachedContents_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_sync.py new file mode 100644 index 000000000000..48987878ea18 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_list_cached_contents_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListCachedContents +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_ListCachedContents_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_cached_contents(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCachedContentsRequest( + ) + + # Make the request + page_result = client.list_cached_contents(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_ListCachedContents_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_async.py new file mode 100644 index 000000000000..75ccf93a8333 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_UpdateCachedContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_update_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCachedContentRequest( + ) + + # Make the request + response = await client.update_cached_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_UpdateCachedContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_sync.py new file mode 100644 index 000000000000..8d08f40b1964 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_cache_service_update_cached_content_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateCachedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_CacheService_UpdateCachedContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_update_cached_content(): + # Create a client + client = generativelanguage_v1alpha.CacheServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCachedContentRequest( + ) + + # Make the request + response = client.update_cached_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_CacheService_UpdateCachedContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_async.py new file mode 100644 index 000000000000..d500fa672e3f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountMessageTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_DiscussService_CountMessageTokens_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_message_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_DiscussService_CountMessageTokens_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_sync.py new file mode 100644 index 000000000000..223eeae8a2b0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountMessageTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_DiscussService_CountMessageTokens_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_count_message_tokens(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.CountMessageTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_message_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_DiscussService_CountMessageTokens_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_async.py new file mode 100644 index 000000000000..a3f95f2e9bc0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateMessage +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_DiscussService_GenerateMessage_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_generate_message(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_message(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_DiscussService_GenerateMessage_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_sync.py new file mode 100644 index 000000000000..1f9247a0e59f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_discuss_service_generate_message_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateMessage +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_DiscussService_GenerateMessage_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_generate_message(): + # Create a client + client = generativelanguage_v1alpha.DiscussServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.MessagePrompt() + prompt.messages.content = "content_value" + + request = generativelanguage_v1alpha.GenerateMessageRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_message(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_DiscussService_GenerateMessage_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_async.py new file mode 100644 index 000000000000..f43d178c8c6b --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateFile +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_CreateFile_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateFileRequest( + ) + + # Make the request + response = await client.create_file(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_FileService_CreateFile_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_sync.py new file mode 100644 index 000000000000..c5a52005687f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_create_file_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateFile +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_CreateFile_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateFileRequest( + ) + + # Make the request + response = client.create_file(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_FileService_CreateFile_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_async.py new file mode 100644 index 000000000000..43417ec52240 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteFile +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_DeleteFile_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteFileRequest( + name="name_value", + ) + + # Make the request + await client.delete_file(request=request) + + +# [END generativelanguage_v1alpha_generated_FileService_DeleteFile_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_sync.py new file mode 100644 index 000000000000..d82d529013ed --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_delete_file_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteFile +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_DeleteFile_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteFileRequest( + name="name_value", + ) + + # Make the request + client.delete_file(request=request) + + +# [END generativelanguage_v1alpha_generated_FileService_DeleteFile_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_async.py new file mode 100644 index 000000000000..a412847208b9 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetFile +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_GetFile_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetFileRequest( + name="name_value", + ) + + # Make the request + response = await client.get_file(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_FileService_GetFile_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_sync.py new file mode 100644 index 000000000000..a079c262d8d5 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_get_file_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetFile +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_GetFile_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_file(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetFileRequest( + name="name_value", + ) + + # Make the request + response = client.get_file(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_FileService_GetFile_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_async.py new file mode 100644 index 000000000000..ef2e48665528 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListFiles +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_ListFiles_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_files(): + # Create a client + client = generativelanguage_v1alpha.FileServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListFilesRequest( + ) + + # Make the request + page_result = client.list_files(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_FileService_ListFiles_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_sync.py new file mode 100644 index 000000000000..959d14b0f484 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_file_service_list_files_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListFiles +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_FileService_ListFiles_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_files(): + # Create a client + client = generativelanguage_v1alpha.FileServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListFilesRequest( + ) + + # Make the request + page_result = client.list_files(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_FileService_ListFiles_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_async.py new file mode 100644 index 000000000000..101edd5d7237 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchEmbedContents +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_BatchEmbedContents_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_batch_embed_contents(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.EmbedContentRequest() + requests.model = "model_value" + + request = generativelanguage_v1alpha.BatchEmbedContentsRequest( + model="model_value", + requests=requests, + ) + + # Make the request + response = await client.batch_embed_contents(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_BatchEmbedContents_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_sync.py new file mode 100644 index 000000000000..cc778d437a35 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchEmbedContents +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_BatchEmbedContents_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_batch_embed_contents(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.EmbedContentRequest() + requests.model = "model_value" + + request = generativelanguage_v1alpha.BatchEmbedContentsRequest( + model="model_value", + requests=requests, + ) + + # Make the request + response = client.batch_embed_contents(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_BatchEmbedContents_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_async.py new file mode 100644 index 000000000000..64acc497b4c7 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_async.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BidiGenerateContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_BidiGenerateContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_bidi_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + setup = generativelanguage_v1alpha.BidiGenerateContentSetup() + setup.model = "model_value" + + request = generativelanguage_v1alpha.BidiGenerateContentClientMessage( + setup=setup, + ) + + # This method expects an iterator which contains + # 'generativelanguage_v1alpha.BidiGenerateContentClientMessage' objects + # Here we create a generator that yields a single `request` for + # demonstrative purposes. + requests = [request] + + def request_generator(): + for request in requests: + yield request + + # Make the request + stream = await client.bidi_generate_content(requests=request_generator()) + + # Handle the response + async for response in stream: + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_BidiGenerateContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_sync.py new file mode 100644 index 000000000000..3327c9b41c86 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_sync.py @@ -0,0 +1,66 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BidiGenerateContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_BidiGenerateContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_bidi_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + setup = generativelanguage_v1alpha.BidiGenerateContentSetup() + setup.model = "model_value" + + request = generativelanguage_v1alpha.BidiGenerateContentClientMessage( + setup=setup, + ) + + # This method expects an iterator which contains + # 'generativelanguage_v1alpha.BidiGenerateContentClientMessage' objects + # Here we create a generator that yields a single `request` for + # demonstrative purposes. + requests = [request] + + def request_generator(): + for request in requests: + yield request + + # Make the request + stream = client.bidi_generate_content(requests=request_generator()) + + # Handle the response + for response in stream: + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_BidiGenerateContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_async.py new file mode 100644 index 000000000000..bcfa04fdd133 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_CountTokens_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_count_tokens(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CountTokensRequest( + model="model_value", + ) + + # Make the request + response = await client.count_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_CountTokens_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_sync.py new file mode 100644 index 000000000000..01fd050b7bcd --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_count_tokens_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_CountTokens_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_count_tokens(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CountTokensRequest( + model="model_value", + ) + + # Make the request + response = client.count_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_CountTokens_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_async.py new file mode 100644 index 000000000000..3ef00ab390ed --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EmbedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_EmbedContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_embed_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedContentRequest( + model="model_value", + ) + + # Make the request + response = await client.embed_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_EmbedContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_sync.py new file mode 100644 index 000000000000..713e100205ac --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_embed_content_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EmbedContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_EmbedContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_embed_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedContentRequest( + model="model_value", + ) + + # Make the request + response = client.embed_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_EmbedContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_async.py new file mode 100644 index 000000000000..f1c6dac52aa3 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateAnswer +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_GenerateAnswer_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_generate_answer(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateAnswerRequest( + model="model_value", + answer_style="VERBOSE", + ) + + # Make the request + response = await client.generate_answer(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_GenerateAnswer_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_sync.py new file mode 100644 index 000000000000..f535aee82342 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_answer_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateAnswer +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_GenerateAnswer_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_generate_answer(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateAnswerRequest( + model="model_value", + answer_style="VERBOSE", + ) + + # Make the request + response = client.generate_answer(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_GenerateAnswer_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_async.py new file mode 100644 index 000000000000..dadd6494fc01 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_GenerateContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + response = await client.generate_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_GenerateContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_sync.py new file mode 100644 index 000000000000..e0f37948b76d --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_generate_content_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_GenerateContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + response = client.generate_content(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_GenerateContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_async.py new file mode 100644 index 000000000000..cf3958d6400f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for StreamGenerateContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_StreamGenerateContent_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_stream_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + stream = await client.stream_generate_content(request=request) + + # Handle the response + async for response in stream: + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_StreamGenerateContent_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_sync.py new file mode 100644 index 000000000000..057f8d11ff2e --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_generative_service_stream_generate_content_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for StreamGenerateContent +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_GenerativeService_StreamGenerateContent_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_stream_generate_content(): + # Create a client + client = generativelanguage_v1alpha.GenerativeServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GenerateContentRequest( + model="model_value", + ) + + # Make the request + stream = client.stream_generate_content(request=request) + + # Handle the response + for response in stream: + print(response) + +# [END generativelanguage_v1alpha_generated_GenerativeService_StreamGenerateContent_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_async.py new file mode 100644 index 000000000000..8b78c9a1fd41 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_async.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_CreateTunedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateTunedModelRequest( + ) + + # Make the request + operation = client.create_tuned_model(request=request) + + print("Waiting for operation to complete...") + + response = (await operation).result() + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_CreateTunedModel_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_sync.py new file mode 100644 index 000000000000..5242694677d7 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_create_tuned_model_sync.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_CreateTunedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateTunedModelRequest( + ) + + # Make the request + operation = client.create_tuned_model(request=request) + + print("Waiting for operation to complete...") + + response = operation.result() + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_CreateTunedModel_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_async.py new file mode 100644 index 000000000000..5f154afa887e --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_DeleteTunedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + await client.delete_tuned_model(request=request) + + +# [END generativelanguage_v1alpha_generated_ModelService_DeleteTunedModel_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_sync.py new file mode 100644 index 000000000000..cc63534199d0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_delete_tuned_model_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_DeleteTunedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteTunedModelRequest( + name="name_value", + ) + + # Make the request + client.delete_tuned_model(request=request) + + +# [END generativelanguage_v1alpha_generated_ModelService_DeleteTunedModel_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_async.py new file mode 100644 index 000000000000..2e492bb9af48 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_GetModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_GetModel_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_sync.py new file mode 100644 index 000000000000..f2f357bfc072 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_model_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_GetModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_GetModel_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_async.py new file mode 100644 index 000000000000..2a2cd50e76b1 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_GetTunedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = await client.get_tuned_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_GetTunedModel_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_sync.py new file mode 100644 index 000000000000..2d0d781c88d0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_get_tuned_model_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_GetTunedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetTunedModelRequest( + name="name_value", + ) + + # Make the request + response = client.get_tuned_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_GetTunedModel_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_async.py new file mode 100644 index 000000000000..345322778756 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListModels +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_ListModels_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_ListModels_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_sync.py new file mode 100644 index 000000000000..9a8e06503616 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_models_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListModels +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_ListModels_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListModelsRequest( + ) + + # Make the request + page_result = client.list_models(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_ListModels_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_async.py new file mode 100644 index 000000000000..61ae3811566a --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListTunedModels +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_ListTunedModels_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_tuned_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListTunedModelsRequest( + ) + + # Make the request + page_result = client.list_tuned_models(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_ListTunedModels_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_sync.py new file mode 100644 index 000000000000..70cda88c1414 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_list_tuned_models_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListTunedModels +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_ListTunedModels_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_tuned_models(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListTunedModelsRequest( + ) + + # Make the request + page_result = client.list_tuned_models(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_ListTunedModels_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_async.py new file mode 100644 index 000000000000..9fa52de8a1b0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_UpdateTunedModel_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_update_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateTunedModelRequest( + ) + + # Make the request + response = await client.update_tuned_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_UpdateTunedModel_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_sync.py new file mode 100644 index 000000000000..c7092fd8ca61 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_model_service_update_tuned_model_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateTunedModel +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_ModelService_UpdateTunedModel_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_update_tuned_model(): + # Create a client + client = generativelanguage_v1alpha.ModelServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateTunedModelRequest( + ) + + # Make the request + response = client.update_tuned_model(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_ModelService_UpdateTunedModel_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_async.py new file mode 100644 index 000000000000..80e2d7160063 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_CreatePermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = await client.create_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_CreatePermission_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_sync.py new file mode 100644 index 000000000000..587f8302cee5 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_create_permission_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_CreatePermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreatePermissionRequest( + parent="parent_value", + ) + + # Make the request + response = client.create_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_CreatePermission_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_async.py new file mode 100644 index 000000000000..ce903e5cf86e --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeletePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_DeletePermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + await client.delete_permission(request=request) + + +# [END generativelanguage_v1alpha_generated_PermissionService_DeletePermission_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_sync.py new file mode 100644 index 000000000000..c1e4b4080053 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_delete_permission_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeletePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_DeletePermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeletePermissionRequest( + name="name_value", + ) + + # Make the request + client.delete_permission(request=request) + + +# [END generativelanguage_v1alpha_generated_PermissionService_DeletePermission_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_async.py new file mode 100644 index 000000000000..8b5ff9e2f79d --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetPermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_GetPermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetPermissionRequest( + name="name_value", + ) + + # Make the request + response = await client.get_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_GetPermission_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_sync.py new file mode 100644 index 000000000000..2fa4ed614567 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_get_permission_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetPermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_GetPermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetPermissionRequest( + name="name_value", + ) + + # Make the request + response = client.get_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_GetPermission_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_async.py new file mode 100644 index 000000000000..01ef5734f136 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListPermissions +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_ListPermissions_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_permissions(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_ListPermissions_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_sync.py new file mode 100644 index 000000000000..3b52a1da909b --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_list_permissions_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListPermissions +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_ListPermissions_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_permissions(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListPermissionsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_permissions(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_ListPermissions_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_async.py new file mode 100644 index 000000000000..968f6c2023ef --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for TransferOwnership +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_TransferOwnership_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = await client.transfer_ownership(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_TransferOwnership_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_sync.py new file mode 100644 index 000000000000..d2b3e909d6c8 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_transfer_ownership_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for TransferOwnership +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_TransferOwnership_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_transfer_ownership(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Make the request + response = client.transfer_ownership(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_TransferOwnership_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_async.py new file mode 100644 index 000000000000..e956061d1c31 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_UpdatePermission_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_update_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdatePermissionRequest( + ) + + # Make the request + response = await client.update_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_UpdatePermission_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_sync.py new file mode 100644 index 000000000000..1e5fbe25f536 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_permission_service_update_permission_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdatePermission +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PermissionService_UpdatePermission_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_update_permission(): + # Create a client + client = generativelanguage_v1alpha.PermissionServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdatePermissionRequest( + ) + + # Make the request + response = client.update_permission(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PermissionService_UpdatePermission_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_async.py new file mode 100644 index 000000000000..2f3022b0a953 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for Predict +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PredictionService_Predict_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_predict(): + # Create a client + client = generativelanguage_v1alpha.PredictionServiceAsyncClient() + + # Initialize request argument(s) + instances = generativelanguage_v1alpha.Value() + instances.null_value = "NULL_VALUE" + + request = generativelanguage_v1alpha.PredictRequest( + model="model_value", + instances=instances, + ) + + # Make the request + response = await client.predict(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PredictionService_Predict_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_sync.py new file mode 100644 index 000000000000..800ae88ad788 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_prediction_service_predict_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for Predict +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_PredictionService_Predict_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_predict(): + # Create a client + client = generativelanguage_v1alpha.PredictionServiceClient() + + # Initialize request argument(s) + instances = generativelanguage_v1alpha.Value() + instances.null_value = "NULL_VALUE" + + request = generativelanguage_v1alpha.PredictRequest( + model="model_value", + instances=instances, + ) + + # Make the request + response = client.predict(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_PredictionService_Predict_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_async.py new file mode 100644 index 000000000000..295179010026 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchCreateChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_BatchCreateChunks_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_batch_create_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.CreateChunkRequest() + requests.parent = "parent_value" + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchCreateChunksRequest( + requests=requests, + ) + + # Make the request + response = await client.batch_create_chunks(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_BatchCreateChunks_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_sync.py new file mode 100644 index 000000000000..6155e653e118 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchCreateChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_BatchCreateChunks_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_batch_create_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.CreateChunkRequest() + requests.parent = "parent_value" + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchCreateChunksRequest( + requests=requests, + ) + + # Make the request + response = client.batch_create_chunks(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_BatchCreateChunks_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_async.py new file mode 100644 index 000000000000..a83d4749248b --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchDeleteChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_BatchDeleteChunks_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_batch_delete_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.DeleteChunkRequest() + requests.name = "name_value" + + request = generativelanguage_v1alpha.BatchDeleteChunksRequest( + requests=requests, + ) + + # Make the request + await client.batch_delete_chunks(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_BatchDeleteChunks_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_sync.py new file mode 100644 index 000000000000..6bc1526c61a5 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchDeleteChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_BatchDeleteChunks_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_batch_delete_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.DeleteChunkRequest() + requests.name = "name_value" + + request = generativelanguage_v1alpha.BatchDeleteChunksRequest( + requests=requests, + ) + + # Make the request + client.batch_delete_chunks(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_BatchDeleteChunks_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_async.py new file mode 100644 index 000000000000..bbf4fdb530e9 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_async.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchUpdateChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_BatchUpdateChunks_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_batch_update_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.UpdateChunkRequest() + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchUpdateChunksRequest( + requests=requests, + ) + + # Make the request + response = await client.batch_update_chunks(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_BatchUpdateChunks_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_sync.py new file mode 100644 index 000000000000..2502a6617d1f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_sync.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchUpdateChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_BatchUpdateChunks_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_batch_update_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + requests = generativelanguage_v1alpha.UpdateChunkRequest() + requests.chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.BatchUpdateChunksRequest( + requests=requests, + ) + + # Make the request + response = client.batch_update_chunks(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_BatchUpdateChunks_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_async.py new file mode 100644 index 000000000000..ffaa5adbf680 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_CreateChunk_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.CreateChunkRequest( + parent="parent_value", + chunk=chunk, + ) + + # Make the request + response = await client.create_chunk(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_CreateChunk_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_sync.py new file mode 100644 index 000000000000..da7f5dd66659 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_chunk_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_CreateChunk_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.CreateChunkRequest( + parent="parent_value", + chunk=chunk, + ) + + # Make the request + response = client.create_chunk(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_CreateChunk_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_async.py new file mode 100644 index 000000000000..4807923df231 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_CreateCorpus_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCorpusRequest( + ) + + # Make the request + response = await client.create_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_CreateCorpus_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_sync.py new file mode 100644 index 000000000000..ffbf03f3980c --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_corpus_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_CreateCorpus_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateCorpusRequest( + ) + + # Make the request + response = client.create_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_CreateCorpus_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_async.py new file mode 100644 index 000000000000..27690e826e2c --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_CreateDocument_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_create_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateDocumentRequest( + parent="parent_value", + ) + + # Make the request + response = await client.create_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_CreateDocument_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_sync.py new file mode 100644 index 000000000000..8b848860c70d --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_create_document_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CreateDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_CreateDocument_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_create_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.CreateDocumentRequest( + parent="parent_value", + ) + + # Make the request + response = client.create_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_CreateDocument_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_async.py new file mode 100644 index 000000000000..67aa1505ea7f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_DeleteChunk_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteChunkRequest( + name="name_value", + ) + + # Make the request + await client.delete_chunk(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_DeleteChunk_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_sync.py new file mode 100644 index 000000000000..3c3b40963e51 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_chunk_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_DeleteChunk_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteChunkRequest( + name="name_value", + ) + + # Make the request + client.delete_chunk(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_DeleteChunk_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_async.py new file mode 100644 index 000000000000..86b2d27ac527 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_DeleteCorpus_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCorpusRequest( + name="name_value", + ) + + # Make the request + await client.delete_corpus(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_DeleteCorpus_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_sync.py new file mode 100644 index 000000000000..958c31c514a6 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_corpus_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_DeleteCorpus_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteCorpusRequest( + name="name_value", + ) + + # Make the request + client.delete_corpus(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_DeleteCorpus_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_async.py new file mode 100644 index 000000000000..de5657b962c0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_async.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_DeleteDocument_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_delete_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteDocumentRequest( + name="name_value", + ) + + # Make the request + await client.delete_document(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_DeleteDocument_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_sync.py new file mode 100644 index 000000000000..3f55752e5ccd --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_delete_document_sync.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for DeleteDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_DeleteDocument_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_delete_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.DeleteDocumentRequest( + name="name_value", + ) + + # Make the request + client.delete_document(request=request) + + +# [END generativelanguage_v1alpha_generated_RetrieverService_DeleteDocument_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_async.py new file mode 100644 index 000000000000..70b201278e7a --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_GetChunk_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetChunkRequest( + name="name_value", + ) + + # Make the request + response = await client.get_chunk(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_GetChunk_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_sync.py new file mode 100644 index 000000000000..dbd3db21212a --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_chunk_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_GetChunk_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetChunkRequest( + name="name_value", + ) + + # Make the request + response = client.get_chunk(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_GetChunk_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_async.py new file mode 100644 index 000000000000..5dfb52b613d5 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_GetCorpus_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCorpusRequest( + name="name_value", + ) + + # Make the request + response = await client.get_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_GetCorpus_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_sync.py new file mode 100644 index 000000000000..6fe326141045 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_corpus_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_GetCorpus_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetCorpusRequest( + name="name_value", + ) + + # Make the request + response = client.get_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_GetCorpus_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_async.py new file mode 100644 index 000000000000..c100b61c41cf --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_GetDocument_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_get_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetDocumentRequest( + name="name_value", + ) + + # Make the request + response = await client.get_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_GetDocument_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_sync.py new file mode 100644 index 000000000000..1bca4cad2d00 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_get_document_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GetDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_GetDocument_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_get_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.GetDocumentRequest( + name="name_value", + ) + + # Make the request + response = client.get_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_GetDocument_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_async.py new file mode 100644 index 000000000000..817ee4bf01d1 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_ListChunks_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListChunksRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_chunks(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_ListChunks_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_sync.py new file mode 100644 index 000000000000..a30732170055 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_chunks_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListChunks +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_ListChunks_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_chunks(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListChunksRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_chunks(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_ListChunks_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_async.py new file mode 100644 index 000000000000..2a37061a4636 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListCorpora +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_ListCorpora_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_corpora(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCorporaRequest( + ) + + # Make the request + page_result = client.list_corpora(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_ListCorpora_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_sync.py new file mode 100644 index 000000000000..b33c84c3e88f --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_corpora_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListCorpora +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_ListCorpora_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_corpora(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListCorporaRequest( + ) + + # Make the request + page_result = client.list_corpora(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_ListCorpora_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_async.py new file mode 100644 index 000000000000..7cff65392bf8 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListDocuments +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_ListDocuments_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_list_documents(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListDocumentsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_documents(request=request) + + # Handle the response + async for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_ListDocuments_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_sync.py new file mode 100644 index 000000000000..934bd69738a4 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_list_documents_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for ListDocuments +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_ListDocuments_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_list_documents(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.ListDocumentsRequest( + parent="parent_value", + ) + + # Make the request + page_result = client.list_documents(request=request) + + # Handle the response + for response in page_result: + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_ListDocuments_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_async.py new file mode 100644 index 000000000000..ad55c5da5caf --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for QueryCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_QueryCorpus_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_query_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryCorpusRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = await client.query_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_QueryCorpus_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_sync.py new file mode 100644 index 000000000000..52f35ca4ad01 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_corpus_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for QueryCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_QueryCorpus_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_query_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryCorpusRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = client.query_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_QueryCorpus_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_async.py new file mode 100644 index 000000000000..908be016cc95 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_async.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for QueryDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_QueryDocument_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_query_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryDocumentRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = await client.query_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_QueryDocument_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_sync.py new file mode 100644 index 000000000000..5d3b84398329 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_query_document_sync.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for QueryDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_QueryDocument_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_query_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.QueryDocumentRequest( + name="name_value", + query="query_value", + ) + + # Make the request + response = client.query_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_QueryDocument_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_async.py new file mode 100644 index 000000000000..62395c93331a --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_async.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_UpdateChunk_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_update_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.UpdateChunkRequest( + chunk=chunk, + ) + + # Make the request + response = await client.update_chunk(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_UpdateChunk_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_sync.py new file mode 100644 index 000000000000..c542e8e2e302 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_chunk_sync.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateChunk +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_UpdateChunk_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_update_chunk(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + chunk = generativelanguage_v1alpha.Chunk() + chunk.data.string_value = "string_value_value" + + request = generativelanguage_v1alpha.UpdateChunkRequest( + chunk=chunk, + ) + + # Make the request + response = client.update_chunk(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_UpdateChunk_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_async.py new file mode 100644 index 000000000000..6b29f4119273 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_UpdateCorpus_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_update_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCorpusRequest( + ) + + # Make the request + response = await client.update_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_UpdateCorpus_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_sync.py new file mode 100644 index 000000000000..a364eff802db --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_corpus_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateCorpus +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_UpdateCorpus_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_update_corpus(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateCorpusRequest( + ) + + # Make the request + response = client.update_corpus(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_UpdateCorpus_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_async.py new file mode 100644 index 000000000000..a532c772f723 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_async.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_UpdateDocument_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_update_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateDocumentRequest( + ) + + # Make the request + response = await client.update_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_UpdateDocument_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_sync.py new file mode 100644 index 000000000000..0037a5f02359 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_retriever_service_update_document_sync.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for UpdateDocument +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_RetrieverService_UpdateDocument_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_update_document(): + # Create a client + client = generativelanguage_v1alpha.RetrieverServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.UpdateDocumentRequest( + ) + + # Make the request + response = client.update_document(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_RetrieverService_UpdateDocument_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_async.py new file mode 100644 index 000000000000..3a27fec77706 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchEmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_BatchEmbedText_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.BatchEmbedTextRequest( + model="model_value", + ) + + # Make the request + response = await client.batch_embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_BatchEmbedText_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_sync.py new file mode 100644 index 000000000000..c01f0f7057e9 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_batch_embed_text_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for BatchEmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_BatchEmbedText_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_batch_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.BatchEmbedTextRequest( + model="model_value", + ) + + # Make the request + response = client.batch_embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_BatchEmbedText_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_async.py new file mode 100644 index 000000000000..499086b4372b --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountTextTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_CountTextTokens_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.count_text_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_CountTextTokens_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_sync.py new file mode 100644 index 000000000000..933c2887e538 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_count_text_tokens_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for CountTextTokens +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_CountTextTokens_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_count_text_tokens(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.CountTextTokensRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.count_text_tokens(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_CountTextTokens_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_async.py new file mode 100644 index 000000000000..2e72d0bb0d20 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_async.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_EmbedText_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedTextRequest( + model="model_value", + ) + + # Make the request + response = await client.embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_EmbedText_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_sync.py new file mode 100644 index 000000000000..fe370d4fa9e0 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_embed_text_sync.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for EmbedText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_EmbedText_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_embed_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + request = generativelanguage_v1alpha.EmbedTextRequest( + model="model_value", + ) + + # Make the request + response = client.embed_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_EmbedText_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_async.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_async.py new file mode 100644 index 000000000000..f7d6e946807a --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_async.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_GenerateText_async] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +async def sample_generate_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceAsyncClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = await client.generate_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_GenerateText_async] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_sync.py b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_sync.py new file mode 100644 index 000000000000..db71efe2b5cd --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/generativelanguage_v1alpha_generated_text_service_generate_text_sync.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Generated code. DO NOT EDIT! +# +# Snippet for GenerateText +# NOTE: This snippet has been automatically generated for illustrative purposes only. +# It may require modifications to work in your environment. + +# To install the latest published package dependency, execute the following: +# python3 -m pip install google-ai-generativelanguage + + +# [START generativelanguage_v1alpha_generated_TextService_GenerateText_sync] +# This snippet has been automatically generated and should be regarded as a +# code template only. +# It will require modifications to work: +# - It may require correct/in-range values for request initialization. +# - It may require specifying regional endpoints when creating the service +# client as shown in: +# https://googleapis.dev/python/google-api-core/latest/client_options.html +from google.ai import generativelanguage_v1alpha + + +def sample_generate_text(): + # Create a client + client = generativelanguage_v1alpha.TextServiceClient() + + # Initialize request argument(s) + prompt = generativelanguage_v1alpha.TextPrompt() + prompt.text = "text_value" + + request = generativelanguage_v1alpha.GenerateTextRequest( + model="model_value", + prompt=prompt, + ) + + # Make the request + response = client.generate_text(request=request) + + # Handle the response + print(response) + +# [END generativelanguage_v1alpha_generated_TextService_GenerateText_sync] diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1.json b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1.json index 32123624e9e5..87bcca60fb53 100644 --- a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1.json +++ b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-ai-generativelanguage", - "version": "0.6.14" + "version": "0.1.0" }, "snippets": [ { diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1alpha.json b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1alpha.json new file mode 100644 index 000000000000..a683a40c49c4 --- /dev/null +++ b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1alpha.json @@ -0,0 +1,9183 @@ +{ + "clientLibrary": { + "apis": [ + { + "id": "google.ai.generativelanguage.v1alpha", + "version": "v1alpha" + } + ], + "language": "PYTHON", + "name": "google-ai-generativelanguage", + "version": "0.1.0" + }, + "snippets": [ + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient", + "shortName": "CacheServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.create_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.CreateCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "CreateCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateCachedContentRequest" + }, + { + "name": "cached_content", + "type": "google.ai.generativelanguage_v1alpha.types.CachedContent" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CachedContent", + "shortName": "create_cached_content" + }, + "description": "Sample for CreateCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_create_cached_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_CreateCachedContent_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_create_cached_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient", + "shortName": "CacheServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient.create_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.CreateCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "CreateCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateCachedContentRequest" + }, + { + "name": "cached_content", + "type": "google.ai.generativelanguage_v1alpha.types.CachedContent" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CachedContent", + "shortName": "create_cached_content" + }, + "description": "Sample for CreateCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_create_cached_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_CreateCachedContent_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_create_cached_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient", + "shortName": "CacheServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.delete_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.DeleteCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "DeleteCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteCachedContentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_cached_content" + }, + "description": "Sample for DeleteCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_delete_cached_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_DeleteCachedContent_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_delete_cached_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient", + "shortName": "CacheServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient.delete_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.DeleteCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "DeleteCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteCachedContentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_cached_content" + }, + "description": "Sample for DeleteCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_delete_cached_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_DeleteCachedContent_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_delete_cached_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient", + "shortName": "CacheServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.get_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.GetCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "GetCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetCachedContentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CachedContent", + "shortName": "get_cached_content" + }, + "description": "Sample for GetCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_get_cached_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_GetCachedContent_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_get_cached_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient", + "shortName": "CacheServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient.get_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.GetCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "GetCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetCachedContentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CachedContent", + "shortName": "get_cached_content" + }, + "description": "Sample for GetCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_get_cached_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_GetCachedContent_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_get_cached_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient", + "shortName": "CacheServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.list_cached_contents", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.ListCachedContents", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "ListCachedContents" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListCachedContentsRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.cache_service.pagers.ListCachedContentsAsyncPager", + "shortName": "list_cached_contents" + }, + "description": "Sample for ListCachedContents", + "file": "generativelanguage_v1alpha_generated_cache_service_list_cached_contents_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_ListCachedContents_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_list_cached_contents_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient", + "shortName": "CacheServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient.list_cached_contents", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.ListCachedContents", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "ListCachedContents" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListCachedContentsRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.cache_service.pagers.ListCachedContentsPager", + "shortName": "list_cached_contents" + }, + "description": "Sample for ListCachedContents", + "file": "generativelanguage_v1alpha_generated_cache_service_list_cached_contents_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_ListCachedContents_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_list_cached_contents_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient", + "shortName": "CacheServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceAsyncClient.update_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.UpdateCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "UpdateCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateCachedContentRequest" + }, + { + "name": "cached_content", + "type": "google.ai.generativelanguage_v1alpha.types.CachedContent" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CachedContent", + "shortName": "update_cached_content" + }, + "description": "Sample for UpdateCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_update_cached_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_UpdateCachedContent_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_update_cached_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient", + "shortName": "CacheServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.CacheServiceClient.update_cached_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService.UpdateCachedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.CacheService", + "shortName": "CacheService" + }, + "shortName": "UpdateCachedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateCachedContentRequest" + }, + { + "name": "cached_content", + "type": "google.ai.generativelanguage_v1alpha.types.CachedContent" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CachedContent", + "shortName": "update_cached_content" + }, + "description": "Sample for UpdateCachedContent", + "file": "generativelanguage_v1alpha_generated_cache_service_update_cached_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_CacheService_UpdateCachedContent_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_cache_service_update_cached_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient", + "shortName": "DiscussServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient.count_message_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService.CountMessageTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "CountMessageTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CountMessageTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.MessagePrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CountMessageTokensResponse", + "shortName": "count_message_tokens" + }, + "description": "Sample for CountMessageTokens", + "file": "generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_DiscussService_CountMessageTokens_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceClient", + "shortName": "DiscussServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceClient.count_message_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService.CountMessageTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "CountMessageTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CountMessageTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.MessagePrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CountMessageTokensResponse", + "shortName": "count_message_tokens" + }, + "description": "Sample for CountMessageTokens", + "file": "generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_DiscussService_CountMessageTokens_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_discuss_service_count_message_tokens_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient", + "shortName": "DiscussServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceAsyncClient.generate_message", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService.GenerateMessage", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "GenerateMessage" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateMessageRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.MessagePrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateMessageResponse", + "shortName": "generate_message" + }, + "description": "Sample for GenerateMessage", + "file": "generativelanguage_v1alpha_generated_discuss_service_generate_message_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_DiscussService_GenerateMessage_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_discuss_service_generate_message_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceClient", + "shortName": "DiscussServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.DiscussServiceClient.generate_message", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService.GenerateMessage", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.DiscussService", + "shortName": "DiscussService" + }, + "shortName": "GenerateMessage" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateMessageRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.MessagePrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateMessageResponse", + "shortName": "generate_message" + }, + "description": "Sample for GenerateMessage", + "file": "generativelanguage_v1alpha_generated_discuss_service_generate_message_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_DiscussService_GenerateMessage_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_discuss_service_generate_message_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient", + "shortName": "FileServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient.create_file", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.CreateFile", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "CreateFile" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateFileRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CreateFileResponse", + "shortName": "create_file" + }, + "description": "Sample for CreateFile", + "file": "generativelanguage_v1alpha_generated_file_service_create_file_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_CreateFile_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_create_file_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient", + "shortName": "FileServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient.create_file", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.CreateFile", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "CreateFile" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateFileRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CreateFileResponse", + "shortName": "create_file" + }, + "description": "Sample for CreateFile", + "file": "generativelanguage_v1alpha_generated_file_service_create_file_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_CreateFile_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_create_file_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient", + "shortName": "FileServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient.delete_file", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.DeleteFile", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "DeleteFile" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteFileRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_file" + }, + "description": "Sample for DeleteFile", + "file": "generativelanguage_v1alpha_generated_file_service_delete_file_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_DeleteFile_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_delete_file_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient", + "shortName": "FileServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient.delete_file", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.DeleteFile", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "DeleteFile" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteFileRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_file" + }, + "description": "Sample for DeleteFile", + "file": "generativelanguage_v1alpha_generated_file_service_delete_file_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_DeleteFile_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_delete_file_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient", + "shortName": "FileServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient.get_file", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.GetFile", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "GetFile" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetFileRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.File", + "shortName": "get_file" + }, + "description": "Sample for GetFile", + "file": "generativelanguage_v1alpha_generated_file_service_get_file_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_GetFile_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_get_file_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient", + "shortName": "FileServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient.get_file", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.GetFile", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "GetFile" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetFileRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.File", + "shortName": "get_file" + }, + "description": "Sample for GetFile", + "file": "generativelanguage_v1alpha_generated_file_service_get_file_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_GetFile_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_get_file_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient", + "shortName": "FileServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceAsyncClient.list_files", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.ListFiles", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "ListFiles" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListFilesRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.file_service.pagers.ListFilesAsyncPager", + "shortName": "list_files" + }, + "description": "Sample for ListFiles", + "file": "generativelanguage_v1alpha_generated_file_service_list_files_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_ListFiles_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_list_files_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient", + "shortName": "FileServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.FileServiceClient.list_files", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService.ListFiles", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.FileService", + "shortName": "FileService" + }, + "shortName": "ListFiles" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListFilesRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.file_service.pagers.ListFilesPager", + "shortName": "list_files" + }, + "description": "Sample for ListFiles", + "file": "generativelanguage_v1alpha_generated_file_service_list_files_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_FileService_ListFiles_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_file_service_list_files_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.batch_embed_contents", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.BatchEmbedContents", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "BatchEmbedContents" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "requests", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsResponse", + "shortName": "batch_embed_contents" + }, + "description": "Sample for BatchEmbedContents", + "file": "generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_BatchEmbedContents_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.batch_embed_contents", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.BatchEmbedContents", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "BatchEmbedContents" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "requests", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.EmbedContentRequest]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchEmbedContentsResponse", + "shortName": "batch_embed_contents" + }, + "description": "Sample for BatchEmbedContents", + "file": "generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_BatchEmbedContents_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_batch_embed_contents_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.bidi_generate_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "BidiGenerateContent" + }, + "parameters": [ + { + "name": "requests", + "type": "Iterator[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentClientMessage]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "Iterable[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentServerMessage]", + "shortName": "bidi_generate_content" + }, + "description": "Sample for BidiGenerateContent", + "file": "generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_BidiGenerateContent_async", + "segments": [ + { + "end": 65, + "start": 27, + "type": "FULL" + }, + { + "end": 65, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 58, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 61, + "start": 59, + "type": "REQUEST_EXECUTION" + }, + { + "end": 66, + "start": 62, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.bidi_generate_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.BidiGenerateContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "BidiGenerateContent" + }, + "parameters": [ + { + "name": "requests", + "type": "Iterator[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentClientMessage]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "Iterable[google.ai.generativelanguage_v1alpha.types.BidiGenerateContentServerMessage]", + "shortName": "bidi_generate_content" + }, + "description": "Sample for BidiGenerateContent", + "file": "generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_BidiGenerateContent_sync", + "segments": [ + { + "end": 65, + "start": 27, + "type": "FULL" + }, + { + "end": 65, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 58, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 61, + "start": 59, + "type": "REQUEST_EXECUTION" + }, + { + "end": 66, + "start": 62, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_bidi_generate_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.count_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.CountTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "CountTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CountTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CountTokensResponse", + "shortName": "count_tokens" + }, + "description": "Sample for CountTokens", + "file": "generativelanguage_v1alpha_generated_generative_service_count_tokens_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_CountTokens_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_count_tokens_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.count_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.CountTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "CountTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CountTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CountTokensResponse", + "shortName": "count_tokens" + }, + "description": "Sample for CountTokens", + "file": "generativelanguage_v1alpha_generated_generative_service_count_tokens_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_CountTokens_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_count_tokens_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.embed_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.EmbedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "EmbedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.EmbedContentRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "content", + "type": "google.ai.generativelanguage_v1alpha.types.Content" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.EmbedContentResponse", + "shortName": "embed_content" + }, + "description": "Sample for EmbedContent", + "file": "generativelanguage_v1alpha_generated_generative_service_embed_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_EmbedContent_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_embed_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.embed_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.EmbedContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "EmbedContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.EmbedContentRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "content", + "type": "google.ai.generativelanguage_v1alpha.types.Content" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.EmbedContentResponse", + "shortName": "embed_content" + }, + "description": "Sample for EmbedContent", + "file": "generativelanguage_v1alpha_generated_generative_service_embed_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_EmbedContent_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_embed_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.generate_answer", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.GenerateAnswer", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "GenerateAnswer" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "safety_settings", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]" + }, + { + "name": "answer_style", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest.AnswerStyle" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateAnswerResponse", + "shortName": "generate_answer" + }, + "description": "Sample for GenerateAnswer", + "file": "generativelanguage_v1alpha_generated_generative_service_generate_answer_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_GenerateAnswer_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_generate_answer_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.generate_answer", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.GenerateAnswer", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "GenerateAnswer" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "safety_settings", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.SafetySetting]" + }, + { + "name": "answer_style", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateAnswerRequest.AnswerStyle" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateAnswerResponse", + "shortName": "generate_answer" + }, + "description": "Sample for GenerateAnswer", + "file": "generativelanguage_v1alpha_generated_generative_service_generate_answer_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_GenerateAnswer_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_generate_answer_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.generate_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.GenerateContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "GenerateContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateContentRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateContentResponse", + "shortName": "generate_content" + }, + "description": "Sample for GenerateContent", + "file": "generativelanguage_v1alpha_generated_generative_service_generate_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_GenerateContent_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_generate_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.generate_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.GenerateContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "GenerateContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateContentRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateContentResponse", + "shortName": "generate_content" + }, + "description": "Sample for GenerateContent", + "file": "generativelanguage_v1alpha_generated_generative_service_generate_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_GenerateContent_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_generate_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient", + "shortName": "GenerativeServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceAsyncClient.stream_generate_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.StreamGenerateContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "StreamGenerateContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateContentRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "Iterable[google.ai.generativelanguage_v1alpha.types.GenerateContentResponse]", + "shortName": "stream_generate_content" + }, + "description": "Sample for StreamGenerateContent", + "file": "generativelanguage_v1alpha_generated_generative_service_stream_generate_content_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_StreamGenerateContent_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_stream_generate_content_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient", + "shortName": "GenerativeServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.GenerativeServiceClient.stream_generate_content", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService.StreamGenerateContent", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.GenerativeService", + "shortName": "GenerativeService" + }, + "shortName": "StreamGenerateContent" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateContentRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "contents", + "type": "MutableSequence[google.ai.generativelanguage_v1alpha.types.Content]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "Iterable[google.ai.generativelanguage_v1alpha.types.GenerateContentResponse]", + "shortName": "stream_generate_content" + }, + "description": "Sample for StreamGenerateContent", + "file": "generativelanguage_v1alpha_generated_generative_service_stream_generate_content_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_GenerativeService_StreamGenerateContent_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_generative_service_stream_generate_content_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.create_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.CreateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "CreateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1alpha.types.TunedModel" + }, + { + "name": "tuned_model_id", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.api_core.operation_async.AsyncOperation", + "shortName": "create_tuned_model" + }, + "description": "Sample for CreateTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_create_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_CreateTunedModel_async", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_create_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.create_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.CreateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "CreateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1alpha.types.TunedModel" + }, + { + "name": "tuned_model_id", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.api_core.operation.Operation", + "shortName": "create_tuned_model" + }, + "description": "Sample for CreateTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_create_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_CreateTunedModel_sync", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_create_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.delete_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.DeleteTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "DeleteTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_tuned_model" + }, + "description": "Sample for DeleteTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_delete_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_DeleteTunedModel_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_delete_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.delete_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.DeleteTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "DeleteTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_tuned_model" + }, + "description": "Sample for DeleteTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_delete_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_DeleteTunedModel_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_delete_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.get_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.GetModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Model", + "shortName": "get_model" + }, + "description": "Sample for GetModel", + "file": "generativelanguage_v1alpha_generated_model_service_get_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_GetModel_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_get_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.get_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.GetModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Model", + "shortName": "get_model" + }, + "description": "Sample for GetModel", + "file": "generativelanguage_v1alpha_generated_model_service_get_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_GetModel_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_get_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.get_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.GetTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.TunedModel", + "shortName": "get_tuned_model" + }, + "description": "Sample for GetTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_get_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_GetTunedModel_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_get_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.get_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.GetTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "GetTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetTunedModelRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.TunedModel", + "shortName": "get_tuned_model" + }, + "description": "Sample for GetTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_get_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_GetTunedModel_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_get_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.list_models", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.ListModels", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListModelsAsyncPager", + "shortName": "list_models" + }, + "description": "Sample for ListModels", + "file": "generativelanguage_v1alpha_generated_model_service_list_models_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_ListModels_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_list_models_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.list_models", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.ListModels", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListModelsPager", + "shortName": "list_models" + }, + "description": "Sample for ListModels", + "file": "generativelanguage_v1alpha_generated_model_service_list_models_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_ListModels_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_list_models_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.list_tuned_models", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.ListTunedModels", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListTunedModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListTunedModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListTunedModelsAsyncPager", + "shortName": "list_tuned_models" + }, + "description": "Sample for ListTunedModels", + "file": "generativelanguage_v1alpha_generated_model_service_list_tuned_models_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_ListTunedModels_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_list_tuned_models_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.list_tuned_models", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.ListTunedModels", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "ListTunedModels" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListTunedModelsRequest" + }, + { + "name": "page_size", + "type": "int" + }, + { + "name": "page_token", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.model_service.pagers.ListTunedModelsPager", + "shortName": "list_tuned_models" + }, + "description": "Sample for ListTunedModels", + "file": "generativelanguage_v1alpha_generated_model_service_list_tuned_models_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_ListTunedModels_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_list_tuned_models_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient", + "shortName": "ModelServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceAsyncClient.update_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.UpdateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "UpdateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1alpha.types.TunedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.TunedModel", + "shortName": "update_tuned_model" + }, + "description": "Sample for UpdateTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_update_tuned_model_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_UpdateTunedModel_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_update_tuned_model_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient", + "shortName": "ModelServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.ModelServiceClient.update_tuned_model", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService.UpdateTunedModel", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.ModelService", + "shortName": "ModelService" + }, + "shortName": "UpdateTunedModel" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateTunedModelRequest" + }, + { + "name": "tuned_model", + "type": "google.ai.generativelanguage_v1alpha.types.TunedModel" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.TunedModel", + "shortName": "update_tuned_model" + }, + "description": "Sample for UpdateTunedModel", + "file": "generativelanguage_v1alpha_generated_model_service_update_tuned_model_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_ModelService_UpdateTunedModel_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_model_service_update_tuned_model_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.create_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.CreatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "CreatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreatePermissionRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1alpha.types.Permission" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Permission", + "shortName": "create_permission" + }, + "description": "Sample for CreatePermission", + "file": "generativelanguage_v1alpha_generated_permission_service_create_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_CreatePermission_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_create_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient.create_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.CreatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "CreatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreatePermissionRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1alpha.types.Permission" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Permission", + "shortName": "create_permission" + }, + "description": "Sample for CreatePermission", + "file": "generativelanguage_v1alpha_generated_permission_service_create_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_CreatePermission_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_create_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.delete_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.DeletePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "DeletePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeletePermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_permission" + }, + "description": "Sample for DeletePermission", + "file": "generativelanguage_v1alpha_generated_permission_service_delete_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_DeletePermission_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_delete_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient.delete_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.DeletePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "DeletePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeletePermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_permission" + }, + "description": "Sample for DeletePermission", + "file": "generativelanguage_v1alpha_generated_permission_service_delete_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_DeletePermission_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_delete_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.get_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.GetPermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "GetPermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetPermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Permission", + "shortName": "get_permission" + }, + "description": "Sample for GetPermission", + "file": "generativelanguage_v1alpha_generated_permission_service_get_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_GetPermission_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_get_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient.get_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.GetPermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "GetPermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetPermissionRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Permission", + "shortName": "get_permission" + }, + "description": "Sample for GetPermission", + "file": "generativelanguage_v1alpha_generated_permission_service_get_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_GetPermission_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_get_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.list_permissions", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.ListPermissions", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "ListPermissions" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListPermissionsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.permission_service.pagers.ListPermissionsAsyncPager", + "shortName": "list_permissions" + }, + "description": "Sample for ListPermissions", + "file": "generativelanguage_v1alpha_generated_permission_service_list_permissions_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_ListPermissions_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_list_permissions_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient.list_permissions", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.ListPermissions", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "ListPermissions" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListPermissionsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.permission_service.pagers.ListPermissionsPager", + "shortName": "list_permissions" + }, + "description": "Sample for ListPermissions", + "file": "generativelanguage_v1alpha_generated_permission_service_list_permissions_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_ListPermissions_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_list_permissions_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.transfer_ownership", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.TransferOwnership", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "TransferOwnership" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.TransferOwnershipRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.TransferOwnershipResponse", + "shortName": "transfer_ownership" + }, + "description": "Sample for TransferOwnership", + "file": "generativelanguage_v1alpha_generated_permission_service_transfer_ownership_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_TransferOwnership_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_transfer_ownership_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient.transfer_ownership", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.TransferOwnership", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "TransferOwnership" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.TransferOwnershipRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.TransferOwnershipResponse", + "shortName": "transfer_ownership" + }, + "description": "Sample for TransferOwnership", + "file": "generativelanguage_v1alpha_generated_permission_service_transfer_ownership_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_TransferOwnership_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_transfer_ownership_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient", + "shortName": "PermissionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceAsyncClient.update_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.UpdatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "UpdatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdatePermissionRequest" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1alpha.types.Permission" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Permission", + "shortName": "update_permission" + }, + "description": "Sample for UpdatePermission", + "file": "generativelanguage_v1alpha_generated_permission_service_update_permission_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_UpdatePermission_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_update_permission_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient", + "shortName": "PermissionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PermissionServiceClient.update_permission", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService.UpdatePermission", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PermissionService", + "shortName": "PermissionService" + }, + "shortName": "UpdatePermission" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdatePermissionRequest" + }, + { + "name": "permission", + "type": "google.ai.generativelanguage_v1alpha.types.Permission" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Permission", + "shortName": "update_permission" + }, + "description": "Sample for UpdatePermission", + "file": "generativelanguage_v1alpha_generated_permission_service_update_permission_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PermissionService_UpdatePermission_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_permission_service_update_permission_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PredictionServiceAsyncClient", + "shortName": "PredictionServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PredictionServiceAsyncClient.predict", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PredictionService.Predict", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PredictionService", + "shortName": "PredictionService" + }, + "shortName": "Predict" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.PredictRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "instances", + "type": "MutableSequence[google.protobuf.struct_pb2.Value]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.PredictResponse", + "shortName": "predict" + }, + "description": "Sample for Predict", + "file": "generativelanguage_v1alpha_generated_prediction_service_predict_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PredictionService_Predict_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_prediction_service_predict_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.PredictionServiceClient", + "shortName": "PredictionServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.PredictionServiceClient.predict", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.PredictionService.Predict", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.PredictionService", + "shortName": "PredictionService" + }, + "shortName": "Predict" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.PredictRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "instances", + "type": "MutableSequence[google.protobuf.struct_pb2.Value]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.PredictResponse", + "shortName": "predict" + }, + "description": "Sample for Predict", + "file": "generativelanguage_v1alpha_generated_prediction_service_predict_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_PredictionService_Predict_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_prediction_service_predict_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.batch_create_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.BatchCreateChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "BatchCreateChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchCreateChunksRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchCreateChunksResponse", + "shortName": "batch_create_chunks" + }, + "description": "Sample for BatchCreateChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_BatchCreateChunks_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.batch_create_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.BatchCreateChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "BatchCreateChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchCreateChunksRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchCreateChunksResponse", + "shortName": "batch_create_chunks" + }, + "description": "Sample for BatchCreateChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_BatchCreateChunks_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_batch_create_chunks_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.batch_delete_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.BatchDeleteChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "BatchDeleteChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchDeleteChunksRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "batch_delete_chunks" + }, + "description": "Sample for BatchDeleteChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_BatchDeleteChunks_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.batch_delete_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.BatchDeleteChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "BatchDeleteChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchDeleteChunksRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "batch_delete_chunks" + }, + "description": "Sample for BatchDeleteChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_BatchDeleteChunks_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_batch_delete_chunks_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.batch_update_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.BatchUpdateChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "BatchUpdateChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksResponse", + "shortName": "batch_update_chunks" + }, + "description": "Sample for BatchUpdateChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_BatchUpdateChunks_async", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.batch_update_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.BatchUpdateChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "BatchUpdateChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchUpdateChunksResponse", + "shortName": "batch_update_chunks" + }, + "description": "Sample for BatchUpdateChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_BatchUpdateChunks_sync", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_batch_update_chunks_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.create_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.CreateChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "CreateChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateChunkRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "chunk", + "type": "google.ai.generativelanguage_v1alpha.types.Chunk" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Chunk", + "shortName": "create_chunk" + }, + "description": "Sample for CreateChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_create_chunk_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_CreateChunk_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_create_chunk_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.create_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.CreateChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "CreateChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateChunkRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "chunk", + "type": "google.ai.generativelanguage_v1alpha.types.Chunk" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Chunk", + "shortName": "create_chunk" + }, + "description": "Sample for CreateChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_create_chunk_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_CreateChunk_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_create_chunk_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.create_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.CreateCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "CreateCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateCorpusRequest" + }, + { + "name": "corpus", + "type": "google.ai.generativelanguage_v1alpha.types.Corpus" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Corpus", + "shortName": "create_corpus" + }, + "description": "Sample for CreateCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_create_corpus_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_CreateCorpus_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_create_corpus_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.create_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.CreateCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "CreateCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateCorpusRequest" + }, + { + "name": "corpus", + "type": "google.ai.generativelanguage_v1alpha.types.Corpus" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Corpus", + "shortName": "create_corpus" + }, + "description": "Sample for CreateCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_create_corpus_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_CreateCorpus_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_create_corpus_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.create_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.CreateDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "CreateDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateDocumentRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "document", + "type": "google.ai.generativelanguage_v1alpha.types.Document" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Document", + "shortName": "create_document" + }, + "description": "Sample for CreateDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_create_document_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_CreateDocument_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_create_document_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.create_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.CreateDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "CreateDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CreateDocumentRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "document", + "type": "google.ai.generativelanguage_v1alpha.types.Document" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Document", + "shortName": "create_document" + }, + "description": "Sample for CreateDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_create_document_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_CreateDocument_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_create_document_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.delete_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.DeleteChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "DeleteChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteChunkRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_chunk" + }, + "description": "Sample for DeleteChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_delete_chunk_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_DeleteChunk_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_delete_chunk_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.delete_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.DeleteChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "DeleteChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteChunkRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_chunk" + }, + "description": "Sample for DeleteChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_delete_chunk_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_DeleteChunk_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_delete_chunk_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.delete_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.DeleteCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "DeleteCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteCorpusRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_corpus" + }, + "description": "Sample for DeleteCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_delete_corpus_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_DeleteCorpus_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_delete_corpus_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.delete_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.DeleteCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "DeleteCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteCorpusRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_corpus" + }, + "description": "Sample for DeleteCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_delete_corpus_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_DeleteCorpus_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_delete_corpus_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.delete_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.DeleteDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "DeleteDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteDocumentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_document" + }, + "description": "Sample for DeleteDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_delete_document_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_DeleteDocument_async", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_delete_document_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.delete_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.DeleteDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "DeleteDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.DeleteDocumentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "shortName": "delete_document" + }, + "description": "Sample for DeleteDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_delete_document_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_DeleteDocument_sync", + "segments": [ + { + "end": 49, + "start": 27, + "type": "FULL" + }, + { + "end": 49, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_delete_document_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.get_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.GetChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "GetChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetChunkRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Chunk", + "shortName": "get_chunk" + }, + "description": "Sample for GetChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_get_chunk_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_GetChunk_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_get_chunk_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.get_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.GetChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "GetChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetChunkRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Chunk", + "shortName": "get_chunk" + }, + "description": "Sample for GetChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_get_chunk_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_GetChunk_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_get_chunk_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.get_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.GetCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "GetCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetCorpusRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Corpus", + "shortName": "get_corpus" + }, + "description": "Sample for GetCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_get_corpus_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_GetCorpus_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_get_corpus_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.get_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.GetCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "GetCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetCorpusRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Corpus", + "shortName": "get_corpus" + }, + "description": "Sample for GetCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_get_corpus_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_GetCorpus_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_get_corpus_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.get_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.GetDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "GetDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetDocumentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Document", + "shortName": "get_document" + }, + "description": "Sample for GetDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_get_document_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_GetDocument_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_get_document_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.get_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.GetDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "GetDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GetDocumentRequest" + }, + { + "name": "name", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Document", + "shortName": "get_document" + }, + "description": "Sample for GetDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_get_document_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_GetDocument_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_get_document_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.list_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.ListChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "ListChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListChunksRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListChunksAsyncPager", + "shortName": "list_chunks" + }, + "description": "Sample for ListChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_list_chunks_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_ListChunks_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_list_chunks_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.list_chunks", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.ListChunks", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "ListChunks" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListChunksRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListChunksPager", + "shortName": "list_chunks" + }, + "description": "Sample for ListChunks", + "file": "generativelanguage_v1alpha_generated_retriever_service_list_chunks_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_ListChunks_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_list_chunks_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.list_corpora", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.ListCorpora", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "ListCorpora" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListCorporaRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListCorporaAsyncPager", + "shortName": "list_corpora" + }, + "description": "Sample for ListCorpora", + "file": "generativelanguage_v1alpha_generated_retriever_service_list_corpora_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_ListCorpora_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_list_corpora_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.list_corpora", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.ListCorpora", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "ListCorpora" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListCorporaRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListCorporaPager", + "shortName": "list_corpora" + }, + "description": "Sample for ListCorpora", + "file": "generativelanguage_v1alpha_generated_retriever_service_list_corpora_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_ListCorpora_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_list_corpora_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.list_documents", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.ListDocuments", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "ListDocuments" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListDocumentsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListDocumentsAsyncPager", + "shortName": "list_documents" + }, + "description": "Sample for ListDocuments", + "file": "generativelanguage_v1alpha_generated_retriever_service_list_documents_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_ListDocuments_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_list_documents_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.list_documents", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.ListDocuments", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "ListDocuments" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.ListDocumentsRequest" + }, + { + "name": "parent", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.services.retriever_service.pagers.ListDocumentsPager", + "shortName": "list_documents" + }, + "description": "Sample for ListDocuments", + "file": "generativelanguage_v1alpha_generated_retriever_service_list_documents_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_ListDocuments_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_list_documents_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.query_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.QueryCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "QueryCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.QueryCorpusRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.QueryCorpusResponse", + "shortName": "query_corpus" + }, + "description": "Sample for QueryCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_query_corpus_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_QueryCorpus_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_query_corpus_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.query_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.QueryCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "QueryCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.QueryCorpusRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.QueryCorpusResponse", + "shortName": "query_corpus" + }, + "description": "Sample for QueryCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_query_corpus_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_QueryCorpus_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_query_corpus_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.query_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.QueryDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "QueryDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.QueryDocumentRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.QueryDocumentResponse", + "shortName": "query_document" + }, + "description": "Sample for QueryDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_query_document_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_QueryDocument_async", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_query_document_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.query_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.QueryDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "QueryDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.QueryDocumentRequest" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.QueryDocumentResponse", + "shortName": "query_document" + }, + "description": "Sample for QueryDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_query_document_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_QueryDocument_sync", + "segments": [ + { + "end": 52, + "start": 27, + "type": "FULL" + }, + { + "end": 52, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 46, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 49, + "start": 47, + "type": "REQUEST_EXECUTION" + }, + { + "end": 53, + "start": 50, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_query_document_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.update_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.UpdateChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "UpdateChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateChunkRequest" + }, + { + "name": "chunk", + "type": "google.ai.generativelanguage_v1alpha.types.Chunk" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Chunk", + "shortName": "update_chunk" + }, + "description": "Sample for UpdateChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_update_chunk_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_UpdateChunk_async", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_update_chunk_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.update_chunk", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.UpdateChunk", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "UpdateChunk" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateChunkRequest" + }, + { + "name": "chunk", + "type": "google.ai.generativelanguage_v1alpha.types.Chunk" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Chunk", + "shortName": "update_chunk" + }, + "description": "Sample for UpdateChunk", + "file": "generativelanguage_v1alpha_generated_retriever_service_update_chunk_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_UpdateChunk_sync", + "segments": [ + { + "end": 54, + "start": 27, + "type": "FULL" + }, + { + "end": 54, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 48, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 51, + "start": 49, + "type": "REQUEST_EXECUTION" + }, + { + "end": 55, + "start": 52, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_update_chunk_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.update_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.UpdateCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "UpdateCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateCorpusRequest" + }, + { + "name": "corpus", + "type": "google.ai.generativelanguage_v1alpha.types.Corpus" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Corpus", + "shortName": "update_corpus" + }, + "description": "Sample for UpdateCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_update_corpus_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_UpdateCorpus_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_update_corpus_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.update_corpus", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.UpdateCorpus", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "UpdateCorpus" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateCorpusRequest" + }, + { + "name": "corpus", + "type": "google.ai.generativelanguage_v1alpha.types.Corpus" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Corpus", + "shortName": "update_corpus" + }, + "description": "Sample for UpdateCorpus", + "file": "generativelanguage_v1alpha_generated_retriever_service_update_corpus_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_UpdateCorpus_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_update_corpus_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient", + "shortName": "RetrieverServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceAsyncClient.update_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.UpdateDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "UpdateDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateDocumentRequest" + }, + { + "name": "document", + "type": "google.ai.generativelanguage_v1alpha.types.Document" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Document", + "shortName": "update_document" + }, + "description": "Sample for UpdateDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_update_document_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_UpdateDocument_async", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_update_document_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient", + "shortName": "RetrieverServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.RetrieverServiceClient.update_document", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService.UpdateDocument", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.RetrieverService", + "shortName": "RetrieverService" + }, + "shortName": "UpdateDocument" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.UpdateDocumentRequest" + }, + { + "name": "document", + "type": "google.ai.generativelanguage_v1alpha.types.Document" + }, + { + "name": "update_mask", + "type": "google.protobuf.field_mask_pb2.FieldMask" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.Document", + "shortName": "update_document" + }, + "description": "Sample for UpdateDocument", + "file": "generativelanguage_v1alpha_generated_retriever_service_update_document_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_RetrieverService_UpdateDocument_sync", + "segments": [ + { + "end": 50, + "start": 27, + "type": "FULL" + }, + { + "end": 50, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 44, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 47, + "start": 45, + "type": "REQUEST_EXECUTION" + }, + { + "end": 51, + "start": 48, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_retriever_service_update_document_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient.batch_embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.BatchEmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "BatchEmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchEmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "texts", + "type": "MutableSequence[str]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchEmbedTextResponse", + "shortName": "batch_embed_text" + }, + "description": "Sample for BatchEmbedText", + "file": "generativelanguage_v1alpha_generated_text_service_batch_embed_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_BatchEmbedText_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_batch_embed_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient.batch_embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.BatchEmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "BatchEmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.BatchEmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "texts", + "type": "MutableSequence[str]" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.BatchEmbedTextResponse", + "shortName": "batch_embed_text" + }, + "description": "Sample for BatchEmbedText", + "file": "generativelanguage_v1alpha_generated_text_service_batch_embed_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_BatchEmbedText_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_batch_embed_text_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient.count_text_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.CountTextTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "CountTextTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CountTextTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.TextPrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CountTextTokensResponse", + "shortName": "count_text_tokens" + }, + "description": "Sample for CountTextTokens", + "file": "generativelanguage_v1alpha_generated_text_service_count_text_tokens_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_CountTextTokens_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_count_text_tokens_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient.count_text_tokens", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.CountTextTokens", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "CountTextTokens" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.CountTextTokensRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.TextPrompt" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.CountTextTokensResponse", + "shortName": "count_text_tokens" + }, + "description": "Sample for CountTextTokens", + "file": "generativelanguage_v1alpha_generated_text_service_count_text_tokens_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_CountTextTokens_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_count_text_tokens_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient.embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.EmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "EmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.EmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "text", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.EmbedTextResponse", + "shortName": "embed_text" + }, + "description": "Sample for EmbedText", + "file": "generativelanguage_v1alpha_generated_text_service_embed_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_EmbedText_async", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_embed_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient.embed_text", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.EmbedText", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "EmbedText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.EmbedTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "text", + "type": "str" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.EmbedTextResponse", + "shortName": "embed_text" + }, + "description": "Sample for EmbedText", + "file": "generativelanguage_v1alpha_generated_text_service_embed_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_EmbedText_sync", + "segments": [ + { + "end": 51, + "start": 27, + "type": "FULL" + }, + { + "end": 51, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 45, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 48, + "start": 46, + "type": "REQUEST_EXECUTION" + }, + { + "end": 52, + "start": 49, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_embed_text_sync.py" + }, + { + "canonical": true, + "clientMethod": { + "async": true, + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient", + "shortName": "TextServiceAsyncClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceAsyncClient.generate_text", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.GenerateText", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "GenerateText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.TextPrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "max_output_tokens", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateTextResponse", + "shortName": "generate_text" + }, + "description": "Sample for GenerateText", + "file": "generativelanguage_v1alpha_generated_text_service_generate_text_async.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_GenerateText_async", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_generate_text_async.py" + }, + { + "canonical": true, + "clientMethod": { + "client": { + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient", + "shortName": "TextServiceClient" + }, + "fullName": "google.ai.generativelanguage_v1alpha.TextServiceClient.generate_text", + "method": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService.GenerateText", + "service": { + "fullName": "google.ai.generativelanguage.v1alpha.TextService", + "shortName": "TextService" + }, + "shortName": "GenerateText" + }, + "parameters": [ + { + "name": "request", + "type": "google.ai.generativelanguage_v1alpha.types.GenerateTextRequest" + }, + { + "name": "model", + "type": "str" + }, + { + "name": "prompt", + "type": "google.ai.generativelanguage_v1alpha.types.TextPrompt" + }, + { + "name": "temperature", + "type": "float" + }, + { + "name": "candidate_count", + "type": "int" + }, + { + "name": "max_output_tokens", + "type": "int" + }, + { + "name": "top_p", + "type": "float" + }, + { + "name": "top_k", + "type": "int" + }, + { + "name": "retry", + "type": "google.api_core.retry.Retry" + }, + { + "name": "timeout", + "type": "float" + }, + { + "name": "metadata", + "type": "Sequence[Tuple[str, Union[str, bytes]]]" + } + ], + "resultType": "google.ai.generativelanguage_v1alpha.types.GenerateTextResponse", + "shortName": "generate_text" + }, + "description": "Sample for GenerateText", + "file": "generativelanguage_v1alpha_generated_text_service_generate_text_sync.py", + "language": "PYTHON", + "origin": "API_DEFINITION", + "regionTag": "generativelanguage_v1alpha_generated_TextService_GenerateText_sync", + "segments": [ + { + "end": 55, + "start": 27, + "type": "FULL" + }, + { + "end": 55, + "start": 27, + "type": "SHORT" + }, + { + "end": 40, + "start": 38, + "type": "CLIENT_INITIALIZATION" + }, + { + "end": 49, + "start": 41, + "type": "REQUEST_INITIALIZATION" + }, + { + "end": 52, + "start": 50, + "type": "REQUEST_EXECUTION" + }, + { + "end": 56, + "start": 53, + "type": "RESPONSE_HANDLING" + } + ], + "title": "generativelanguage_v1alpha_generated_text_service_generate_text_sync.py" + } + ] +} diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta.json b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta.json index 3f701f0c95ae..ad4cc6847367 100644 --- a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta.json +++ b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-ai-generativelanguage", - "version": "0.6.14" + "version": "0.1.0" }, "snippets": [ { diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json index 6711a3cbfb9f..8eb2001adab2 100644 --- a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json +++ b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta2.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-ai-generativelanguage", - "version": "0.6.14" + "version": "0.1.0" }, "snippets": [ { diff --git a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json index 117d37b5d52f..2d108f936e84 100644 --- a/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json +++ b/packages/google-ai-generativelanguage/samples/generated_samples/snippet_metadata_google.ai.generativelanguage.v1beta3.json @@ -8,7 +8,7 @@ ], "language": "PYTHON", "name": "google-ai-generativelanguage", - "version": "0.6.14" + "version": "0.1.0" }, "snippets": [ { diff --git a/packages/google-ai-generativelanguage/scripts/fixup_generativelanguage_v1alpha_keywords.py b/packages/google-ai-generativelanguage/scripts/fixup_generativelanguage_v1alpha_keywords.py new file mode 100644 index 000000000000..1e4d9e291ab0 --- /dev/null +++ b/packages/google-ai-generativelanguage/scripts/fixup_generativelanguage_v1alpha_keywords.py @@ -0,0 +1,231 @@ +#! /usr/bin/env python3 +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import argparse +import os +import libcst as cst +import pathlib +import sys +from typing import (Any, Callable, Dict, List, Sequence, Tuple) + + +def partition( + predicate: Callable[[Any], bool], + iterator: Sequence[Any] +) -> Tuple[List[Any], List[Any]]: + """A stable, out-of-place partition.""" + results = ([], []) + + for i in iterator: + results[int(predicate(i))].append(i) + + # Returns trueList, falseList + return results[1], results[0] + + +class generativelanguageCallTransformer(cst.CSTTransformer): + CTRL_PARAMS: Tuple[str] = ('retry', 'timeout', 'metadata') + METHOD_TO_PARAMS: Dict[str, Tuple[str]] = { + 'batch_create_chunks': ('requests', 'parent', ), + 'batch_delete_chunks': ('requests', 'parent', ), + 'batch_embed_contents': ('model', 'requests', ), + 'batch_embed_text': ('model', 'texts', 'requests', ), + 'batch_update_chunks': ('requests', 'parent', ), + 'bidi_generate_content': ('setup', 'client_content', 'realtime_input', 'tool_response', ), + 'count_message_tokens': ('model', 'prompt', ), + 'count_text_tokens': ('model', 'prompt', ), + 'count_tokens': ('model', 'contents', 'generate_content_request', ), + 'create_cached_content': ('cached_content', ), + 'create_chunk': ('parent', 'chunk', ), + 'create_corpus': ('corpus', ), + 'create_document': ('parent', 'document', ), + 'create_file': ('file', ), + 'create_permission': ('parent', 'permission', ), + 'create_tuned_model': ('tuned_model', 'tuned_model_id', ), + 'delete_cached_content': ('name', ), + 'delete_chunk': ('name', ), + 'delete_corpus': ('name', 'force', ), + 'delete_document': ('name', 'force', ), + 'delete_file': ('name', ), + 'delete_permission': ('name', ), + 'delete_tuned_model': ('name', ), + 'embed_content': ('model', 'content', 'task_type', 'title', 'output_dimensionality', ), + 'embed_text': ('model', 'text', ), + 'generate_answer': ('model', 'contents', 'answer_style', 'inline_passages', 'semantic_retriever', 'safety_settings', 'temperature', ), + 'generate_content': ('model', 'contents', 'system_instruction', 'tools', 'tool_config', 'safety_settings', 'generation_config', 'cached_content', ), + 'generate_message': ('model', 'prompt', 'temperature', 'candidate_count', 'top_p', 'top_k', ), + 'generate_text': ('model', 'prompt', 'temperature', 'candidate_count', 'max_output_tokens', 'top_p', 'top_k', 'safety_settings', 'stop_sequences', ), + 'get_cached_content': ('name', ), + 'get_chunk': ('name', ), + 'get_corpus': ('name', ), + 'get_document': ('name', ), + 'get_file': ('name', ), + 'get_model': ('name', ), + 'get_permission': ('name', ), + 'get_tuned_model': ('name', ), + 'list_cached_contents': ('page_size', 'page_token', ), + 'list_chunks': ('parent', 'page_size', 'page_token', ), + 'list_corpora': ('page_size', 'page_token', ), + 'list_documents': ('parent', 'page_size', 'page_token', ), + 'list_files': ('page_size', 'page_token', ), + 'list_models': ('page_size', 'page_token', ), + 'list_permissions': ('parent', 'page_size', 'page_token', ), + 'list_tuned_models': ('page_size', 'page_token', 'filter', ), + 'predict': ('model', 'instances', 'parameters', ), + 'query_corpus': ('name', 'query', 'metadata_filters', 'results_count', ), + 'query_document': ('name', 'query', 'results_count', 'metadata_filters', ), + 'stream_generate_content': ('model', 'contents', 'system_instruction', 'tools', 'tool_config', 'safety_settings', 'generation_config', 'cached_content', ), + 'transfer_ownership': ('name', 'email_address', ), + 'update_cached_content': ('cached_content', 'update_mask', ), + 'update_chunk': ('chunk', 'update_mask', ), + 'update_corpus': ('corpus', 'update_mask', ), + 'update_document': ('document', 'update_mask', ), + 'update_permission': ('permission', 'update_mask', ), + 'update_tuned_model': ('tuned_model', 'update_mask', ), + } + + def leave_Call(self, original: cst.Call, updated: cst.Call) -> cst.CSTNode: + try: + key = original.func.attr.value + kword_params = self.METHOD_TO_PARAMS[key] + except (AttributeError, KeyError): + # Either not a method from the API or too convoluted to be sure. + return updated + + # If the existing code is valid, keyword args come after positional args. + # Therefore, all positional args must map to the first parameters. + args, kwargs = partition(lambda a: not bool(a.keyword), updated.args) + if any(k.keyword.value == "request" for k in kwargs): + # We've already fixed this file, don't fix it again. + return updated + + kwargs, ctrl_kwargs = partition( + lambda a: a.keyword.value not in self.CTRL_PARAMS, + kwargs + ) + + args, ctrl_args = args[:len(kword_params)], args[len(kword_params):] + ctrl_kwargs.extend(cst.Arg(value=a.value, keyword=cst.Name(value=ctrl)) + for a, ctrl in zip(ctrl_args, self.CTRL_PARAMS)) + + request_arg = cst.Arg( + value=cst.Dict([ + cst.DictElement( + cst.SimpleString("'{}'".format(name)), +cst.Element(value=arg.value) + ) + # Note: the args + kwargs looks silly, but keep in mind that + # the control parameters had to be stripped out, and that + # those could have been passed positionally or by keyword. + for name, arg in zip(kword_params, args + kwargs)]), + keyword=cst.Name("request") + ) + + return updated.with_changes( + args=[request_arg] + ctrl_kwargs + ) + + +def fix_files( + in_dir: pathlib.Path, + out_dir: pathlib.Path, + *, + transformer=generativelanguageCallTransformer(), +): + """Duplicate the input dir to the output dir, fixing file method calls. + + Preconditions: + * in_dir is a real directory + * out_dir is a real, empty directory + """ + pyfile_gen = ( + pathlib.Path(os.path.join(root, f)) + for root, _, files in os.walk(in_dir) + for f in files if os.path.splitext(f)[1] == ".py" + ) + + for fpath in pyfile_gen: + with open(fpath, 'r') as f: + src = f.read() + + # Parse the code and insert method call fixes. + tree = cst.parse_module(src) + updated = tree.visit(transformer) + + # Create the path and directory structure for the new file. + updated_path = out_dir.joinpath(fpath.relative_to(in_dir)) + updated_path.parent.mkdir(parents=True, exist_ok=True) + + # Generate the updated source file at the corresponding path. + with open(updated_path, 'w') as f: + f.write(updated.code) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="""Fix up source that uses the generativelanguage client library. + +The existing sources are NOT overwritten but are copied to output_dir with changes made. + +Note: This tool operates at a best-effort level at converting positional + parameters in client method calls to keyword based parameters. + Cases where it WILL FAIL include + A) * or ** expansion in a method call. + B) Calls via function or method alias (includes free function calls) + C) Indirect or dispatched calls (e.g. the method is looked up dynamically) + + These all constitute false negatives. The tool will also detect false + positives when an API method shares a name with another method. +""") + parser.add_argument( + '-d', + '--input-directory', + required=True, + dest='input_dir', + help='the input directory to walk for python files to fix up', + ) + parser.add_argument( + '-o', + '--output-directory', + required=True, + dest='output_dir', + help='the directory to output files fixed via un-flattening', + ) + args = parser.parse_args() + input_dir = pathlib.Path(args.input_dir) + output_dir = pathlib.Path(args.output_dir) + if not input_dir.is_dir(): + print( + f"input directory '{input_dir}' does not exist or is not a directory", + file=sys.stderr, + ) + sys.exit(-1) + + if not output_dir.is_dir(): + print( + f"output directory '{output_dir}' does not exist or is not a directory", + file=sys.stderr, + ) + sys.exit(-1) + + if os.listdir(output_dir): + print( + f"output directory '{output_dir}' is not empty", + file=sys.stderr, + ) + sys.exit(-1) + + fix_files(input_dir, output_dir) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/__init__.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/__init__.py new file mode 100644 index 000000000000..8f6cf068242c --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/__init__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_cache_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_cache_service.py new file mode 100644 index 000000000000..47c544657934 --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_cache_service.py @@ -0,0 +1,6111 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import struct_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.cache_service import ( + CacheServiceAsyncClient, + CacheServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1alpha.types import ( + cached_content as gag_cached_content, +) +from google.ai.generativelanguage_v1alpha.types import cache_service +from google.ai.generativelanguage_v1alpha.types import cached_content +from google.ai.generativelanguage_v1alpha.types import content + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert CacheServiceClient._get_default_mtls_endpoint(None) is None + assert ( + CacheServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + CacheServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + CacheServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + CacheServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert CacheServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +def test__read_environment_variables(): + assert CacheServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert CacheServiceClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert CacheServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + CacheServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert CacheServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert CacheServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert CacheServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + CacheServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert CacheServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert CacheServiceClient._get_client_cert_source(None, False) is None + assert ( + CacheServiceClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + CacheServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + CacheServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + CacheServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + CacheServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceClient), +) +@mock.patch.object( + CacheServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = CacheServiceClient._DEFAULT_UNIVERSE + default_endpoint = CacheServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = CacheServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + CacheServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + CacheServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == CacheServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + CacheServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + CacheServiceClient._get_api_endpoint(None, None, default_universe, "always") + == CacheServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + CacheServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == CacheServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + CacheServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + CacheServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + CacheServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + CacheServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + CacheServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + CacheServiceClient._get_universe_domain(None, None) + == CacheServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + CacheServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (CacheServiceClient, "grpc"), + (CacheServiceAsyncClient, "grpc_asyncio"), + (CacheServiceClient, "rest"), + ], +) +def test_cache_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.CacheServiceGrpcTransport, "grpc"), + (transports.CacheServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.CacheServiceRestTransport, "rest"), + ], +) +def test_cache_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (CacheServiceClient, "grpc"), + (CacheServiceAsyncClient, "grpc_asyncio"), + (CacheServiceClient, "rest"), + ], +) +def test_cache_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_cache_service_client_get_transport_class(): + transport = CacheServiceClient.get_transport_class() + available_transports = [ + transports.CacheServiceGrpcTransport, + transports.CacheServiceRestTransport, + ] + assert transport in available_transports + + transport = CacheServiceClient.get_transport_class("grpc") + assert transport == transports.CacheServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (CacheServiceClient, transports.CacheServiceGrpcTransport, "grpc"), + ( + CacheServiceAsyncClient, + transports.CacheServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (CacheServiceClient, transports.CacheServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + CacheServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceClient), +) +@mock.patch.object( + CacheServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceAsyncClient), +) +def test_cache_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(CacheServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(CacheServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (CacheServiceClient, transports.CacheServiceGrpcTransport, "grpc", "true"), + ( + CacheServiceAsyncClient, + transports.CacheServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (CacheServiceClient, transports.CacheServiceGrpcTransport, "grpc", "false"), + ( + CacheServiceAsyncClient, + transports.CacheServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (CacheServiceClient, transports.CacheServiceRestTransport, "rest", "true"), + (CacheServiceClient, transports.CacheServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + CacheServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceClient), +) +@mock.patch.object( + CacheServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_cache_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [CacheServiceClient, CacheServiceAsyncClient]) +@mock.patch.object( + CacheServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(CacheServiceClient) +) +@mock.patch.object( + CacheServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(CacheServiceAsyncClient), +) +def test_cache_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize("client_class", [CacheServiceClient, CacheServiceAsyncClient]) +@mock.patch.object( + CacheServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceClient), +) +@mock.patch.object( + CacheServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(CacheServiceAsyncClient), +) +def test_cache_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = CacheServiceClient._DEFAULT_UNIVERSE + default_endpoint = CacheServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = CacheServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (CacheServiceClient, transports.CacheServiceGrpcTransport, "grpc"), + ( + CacheServiceAsyncClient, + transports.CacheServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (CacheServiceClient, transports.CacheServiceRestTransport, "rest"), + ], +) +def test_cache_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + CacheServiceClient, + transports.CacheServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + CacheServiceAsyncClient, + transports.CacheServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (CacheServiceClient, transports.CacheServiceRestTransport, "rest", None), + ], +) +def test_cache_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_cache_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.cache_service.transports.CacheServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = CacheServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + CacheServiceClient, + transports.CacheServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + CacheServiceAsyncClient, + transports.CacheServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_cache_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.ListCachedContentsRequest, + dict, + ], +) +def test_list_cached_contents(request_type, transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cache_service.ListCachedContentsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_cached_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = cache_service.ListCachedContentsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCachedContentsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_cached_contents_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = cache_service.ListCachedContentsRequest( + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_cached_contents(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cache_service.ListCachedContentsRequest( + page_token="page_token_value", + ) + + +def test_list_cached_contents_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_cached_contents in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_cached_contents + ] = mock_rpc + request = {} + client.list_cached_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_cached_contents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_cached_contents_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_cached_contents + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_cached_contents + ] = mock_rpc + + request = {} + await client.list_cached_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_cached_contents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_cached_contents_async( + transport: str = "grpc_asyncio", + request_type=cache_service.ListCachedContentsRequest, +): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cache_service.ListCachedContentsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_cached_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = cache_service.ListCachedContentsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCachedContentsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_cached_contents_async_from_dict(): + await test_list_cached_contents_async(request_type=dict) + + +def test_list_cached_contents_pager(transport_name: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + next_page_token="abc", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[], + next_page_token="def", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + ], + next_page_token="ghi", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + pager = client.list_cached_contents(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, cached_content.CachedContent) for i in results) + + +def test_list_cached_contents_pages(transport_name: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + next_page_token="abc", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[], + next_page_token="def", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + ], + next_page_token="ghi", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + ), + RuntimeError, + ) + pages = list(client.list_cached_contents(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_cached_contents_async_pager(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + next_page_token="abc", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[], + next_page_token="def", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + ], + next_page_token="ghi", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_cached_contents( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, cached_content.CachedContent) for i in responses) + + +@pytest.mark.asyncio +async def test_list_cached_contents_async_pages(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + next_page_token="abc", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[], + next_page_token="def", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + ], + next_page_token="ghi", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_cached_contents(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.CreateCachedContentRequest, + dict, + ], +) +def test_create_cached_content(request_type, transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + response = client.create_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = cache_service.CreateCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +def test_create_cached_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = cache_service.CreateCachedContentRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_cached_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cache_service.CreateCachedContentRequest() + + +def test_create_cached_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_cached_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_cached_content + ] = mock_rpc + request = {} + client.create_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_cached_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_cached_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_cached_content + ] = mock_rpc + + request = {} + await client.create_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.create_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_cached_content_async( + transport: str = "grpc_asyncio", + request_type=cache_service.CreateCachedContentRequest, +): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + ) + response = await client.create_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = cache_service.CreateCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +@pytest.mark.asyncio +async def test_create_cached_content_async_from_dict(): + await test_create_cached_content_async(request_type=dict) + + +def test_create_cached_content_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_cached_content.CachedContent() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_cached_content( + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].cached_content + mock_val = gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ) + assert arg == mock_val + + +def test_create_cached_content_flattened_error(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_cached_content( + cache_service.CreateCachedContentRequest(), + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + ) + + +@pytest.mark.asyncio +async def test_create_cached_content_flattened_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_cached_content.CachedContent() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_cached_content( + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].cached_content + mock_val = gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_cached_content_flattened_error_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_cached_content( + cache_service.CreateCachedContentRequest(), + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.GetCachedContentRequest, + dict, + ], +) +def test_get_cached_content(request_type, transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + response = client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = cache_service.GetCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +def test_get_cached_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = cache_service.GetCachedContentRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_cached_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cache_service.GetCachedContentRequest( + name="name_value", + ) + + +def test_get_cached_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_cached_content in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_cached_content + ] = mock_rpc + request = {} + client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_cached_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_cached_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_cached_content + ] = mock_rpc + + request = {} + await client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_cached_content_async( + transport: str = "grpc_asyncio", request_type=cache_service.GetCachedContentRequest +): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + ) + response = await client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = cache_service.GetCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +@pytest.mark.asyncio +async def test_get_cached_content_async_from_dict(): + await test_get_cached_content_async(request_type=dict) + + +def test_get_cached_content_field_headers(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cache_service.GetCachedContentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + call.return_value = cached_content.CachedContent() + client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_cached_content_field_headers_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cache_service.GetCachedContentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cached_content.CachedContent() + ) + await client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_cached_content_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cached_content.CachedContent() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_cached_content( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_cached_content_flattened_error(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_cached_content( + cache_service.GetCachedContentRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_cached_content_flattened_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = cached_content.CachedContent() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cached_content.CachedContent() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_cached_content( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_cached_content_flattened_error_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_cached_content( + cache_service.GetCachedContentRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.UpdateCachedContentRequest, + dict, + ], +) +def test_update_cached_content(request_type, transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + response = client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = cache_service.UpdateCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +def test_update_cached_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = cache_service.UpdateCachedContentRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_cached_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cache_service.UpdateCachedContentRequest() + + +def test_update_cached_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_cached_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_cached_content + ] = mock_rpc + request = {} + client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_cached_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_cached_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_cached_content + ] = mock_rpc + + request = {} + await client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.update_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_cached_content_async( + transport: str = "grpc_asyncio", + request_type=cache_service.UpdateCachedContentRequest, +): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + ) + response = await client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = cache_service.UpdateCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +@pytest.mark.asyncio +async def test_update_cached_content_async_from_dict(): + await test_update_cached_content_async(request_type=dict) + + +def test_update_cached_content_field_headers(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cache_service.UpdateCachedContentRequest() + + request.cached_content.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + call.return_value = gag_cached_content.CachedContent() + client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "cached_content.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_cached_content_field_headers_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cache_service.UpdateCachedContentRequest() + + request.cached_content.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent() + ) + await client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "cached_content.name=name_value", + ) in kw["metadata"] + + +def test_update_cached_content_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_cached_content.CachedContent() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_cached_content( + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].cached_content + mock_val = gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_cached_content_flattened_error(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_cached_content( + cache_service.UpdateCachedContentRequest(), + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_cached_content_flattened_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_cached_content.CachedContent() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_cached_content( + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].cached_content + mock_val = gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_cached_content_flattened_error_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_cached_content( + cache_service.UpdateCachedContentRequest(), + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.DeleteCachedContentRequest, + dict, + ], +) +def test_delete_cached_content(request_type, transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = cache_service.DeleteCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_cached_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = cache_service.DeleteCachedContentRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_cached_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == cache_service.DeleteCachedContentRequest( + name="name_value", + ) + + +def test_delete_cached_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_cached_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_cached_content + ] = mock_rpc + request = {} + client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_cached_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_cached_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_cached_content + ] = mock_rpc + + request = {} + await client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_cached_content_async( + transport: str = "grpc_asyncio", + request_type=cache_service.DeleteCachedContentRequest, +): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = cache_service.DeleteCachedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_cached_content_async_from_dict(): + await test_delete_cached_content_async(request_type=dict) + + +def test_delete_cached_content_field_headers(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cache_service.DeleteCachedContentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + call.return_value = None + client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_cached_content_field_headers_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = cache_service.DeleteCachedContentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_cached_content_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_cached_content( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_cached_content_flattened_error(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_cached_content( + cache_service.DeleteCachedContentRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_cached_content_flattened_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_cached_content( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_cached_content_flattened_error_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_cached_content( + cache_service.DeleteCachedContentRequest(), + name="name_value", + ) + + +def test_list_cached_contents_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.list_cached_contents in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_cached_contents + ] = mock_rpc + + request = {} + client.list_cached_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_cached_contents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_cached_contents_rest_pager(transport: str = "rest"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + next_page_token="abc", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[], + next_page_token="def", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + ], + next_page_token="ghi", + ), + cache_service.ListCachedContentsResponse( + cached_contents=[ + cached_content.CachedContent(), + cached_content.CachedContent(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + cache_service.ListCachedContentsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_cached_contents(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, cached_content.CachedContent) for i in results) + + pages = list(client.list_cached_contents(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_create_cached_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_cached_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_cached_content + ] = mock_rpc + + request = {} + client.create_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_cached_content_rest_required_fields( + request_type=cache_service.CreateCachedContentRequest, +): + transport_class = transports.CacheServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_cached_content.CachedContent() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_cached_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_cached_content_rest_unset_required_fields(): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_cached_content._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("cachedContent",))) + + +def test_create_cached_content_rest_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_cached_content.CachedContent() + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gag_cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_cached_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/cachedContents" % client.transport._host, args[1] + ) + + +def test_create_cached_content_rest_flattened_error(transport: str = "rest"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_cached_content( + cache_service.CreateCachedContentRequest(), + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + ) + + +def test_get_cached_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.get_cached_content in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.get_cached_content + ] = mock_rpc + + request = {} + client.get_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_cached_content_rest_required_fields( + request_type=cache_service.GetCachedContentRequest, +): + transport_class = transports.CacheServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = cached_content.CachedContent() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_cached_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_cached_content_rest_unset_required_fields(): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_cached_content._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_cached_content_rest_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = cached_content.CachedContent() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "cachedContents/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_cached_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=cachedContents/*}" % client.transport._host, args[1] + ) + + +def test_get_cached_content_rest_flattened_error(transport: str = "rest"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_cached_content( + cache_service.GetCachedContentRequest(), + name="name_value", + ) + + +def test_update_cached_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_cached_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_cached_content + ] = mock_rpc + + request = {} + client.update_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_cached_content_rest_required_fields( + request_type=cache_service.UpdateCachedContentRequest, +): + transport_class = transports.CacheServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_cached_content._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_cached_content.CachedContent() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_cached_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_cached_content_rest_unset_required_fields(): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_cached_content._get_unset_required_fields({}) + assert set(unset_fields) == (set(("updateMask",)) & set(("cachedContent",))) + + +def test_update_cached_content_rest_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_cached_content.CachedContent() + + # get arguments that satisfy an http rule for this method + sample_request = {"cached_content": {"name": "cachedContents/sample1"}} + + # get truthy value for each flattened field + mock_args = dict( + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gag_cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_cached_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{cached_content.name=cachedContents/*}" + % client.transport._host, + args[1], + ) + + +def test_update_cached_content_rest_flattened_error(transport: str = "rest"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_cached_content( + cache_service.UpdateCachedContentRequest(), + cached_content=gag_cached_content.CachedContent( + expire_time=timestamp_pb2.Timestamp(seconds=751) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_cached_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_cached_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_cached_content + ] = mock_rpc + + request = {} + client.delete_cached_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_cached_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_cached_content_rest_required_fields( + request_type=cache_service.DeleteCachedContentRequest, +): + transport_class = transports.CacheServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_cached_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_cached_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_cached_content_rest_unset_required_fields(): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_cached_content._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_delete_cached_content_rest_flattened(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "cachedContents/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_cached_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=cachedContents/*}" % client.transport._host, args[1] + ) + + +def test_delete_cached_content_rest_flattened_error(transport: str = "rest"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_cached_content( + cache_service.DeleteCachedContentRequest(), + name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.CacheServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.CacheServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = CacheServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.CacheServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = CacheServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = CacheServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.CacheServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = CacheServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.CacheServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = CacheServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.CacheServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.CacheServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.CacheServiceGrpcTransport, + transports.CacheServiceGrpcAsyncIOTransport, + transports.CacheServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = CacheServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_cached_contents_empty_call_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + call.return_value = cache_service.ListCachedContentsResponse() + client.list_cached_contents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.ListCachedContentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_cached_content_empty_call_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + call.return_value = gag_cached_content.CachedContent() + client.create_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.CreateCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_cached_content_empty_call_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + call.return_value = cached_content.CachedContent() + client.get_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.GetCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_cached_content_empty_call_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + call.return_value = gag_cached_content.CachedContent() + client.update_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.UpdateCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_cached_content_empty_call_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + call.return_value = None + client.delete_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.DeleteCachedContentRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = CacheServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_cached_contents_empty_call_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cache_service.ListCachedContentsResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_cached_contents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.ListCachedContentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_cached_content_empty_call_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + ) + await client.create_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.CreateCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_cached_content_empty_call_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + ) + await client.get_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.GetCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_cached_content_empty_call_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + ) + await client.update_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.UpdateCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_cached_content_empty_call_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.DeleteCachedContentRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = CacheServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_list_cached_contents_rest_bad_request( + request_type=cache_service.ListCachedContentsRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_cached_contents(request) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.ListCachedContentsRequest, + dict, + ], +) +def test_list_cached_contents_rest_call_success(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = cache_service.ListCachedContentsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = cache_service.ListCachedContentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_cached_contents(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCachedContentsPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_cached_contents_rest_interceptors(null_interceptor): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.CacheServiceRestInterceptor(), + ) + client = CacheServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.CacheServiceRestInterceptor, "post_list_cached_contents" + ) as post, mock.patch.object( + transports.CacheServiceRestInterceptor, "pre_list_cached_contents" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = cache_service.ListCachedContentsRequest.pb( + cache_service.ListCachedContentsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = cache_service.ListCachedContentsResponse.to_json( + cache_service.ListCachedContentsResponse() + ) + req.return_value.content = return_value + + request = cache_service.ListCachedContentsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = cache_service.ListCachedContentsResponse() + + client.list_cached_contents( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_cached_content_rest_bad_request( + request_type=cache_service.CreateCachedContentRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_cached_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.CreateCachedContentRequest, + dict, + ], +) +def test_create_cached_content_rest_call_success(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request_init["cached_content"] = { + "expire_time": {"seconds": 751, "nanos": 543}, + "ttl": {"seconds": 751, "nanos": 543}, + "name": "name_value", + "display_name": "display_name_value", + "model": "model_value", + "system_instruction": { + "parts": [ + { + "text": "text_value", + "inline_data": { + "mime_type": "mime_type_value", + "data": b"data_blob", + }, + "function_call": { + "id": "id_value", + "name": "name_value", + "args": {"fields": {}}, + }, + "function_response": { + "id": "id_value", + "name": "name_value", + "response": {}, + }, + "file_data": { + "mime_type": "mime_type_value", + "file_uri": "file_uri_value", + }, + "executable_code": {"language": 1, "code": "code_value"}, + "code_execution_result": {"outcome": 1, "output": "output_value"}, + } + ], + "role": "role_value", + }, + "contents": {}, + "tools": [ + { + "function_declarations": [ + { + "name": "name_value", + "description": "description_value", + "parameters": { + "type_": 1, + "format_": "format__value", + "description": "description_value", + "nullable": True, + "enum": ["enum_value1", "enum_value2"], + "items": {}, + "max_items": 967, + "min_items": 965, + "properties": {}, + "required": ["required_value1", "required_value2"], + }, + "response": {}, + } + ], + "google_search_retrieval": { + "dynamic_retrieval_config": {"mode": 1, "dynamic_threshold": 0.1809} + }, + "code_execution": {}, + "google_search": {}, + } + ], + "tool_config": { + "function_calling_config": { + "mode": 1, + "allowed_function_names": [ + "allowed_function_names_value1", + "allowed_function_names_value2", + ], + } + }, + "create_time": {}, + "update_time": {}, + "usage_metadata": {"total_token_count": 1836}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = cache_service.CreateCachedContentRequest.meta.fields["cached_content"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["cached_content"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["cached_content"][field])): + del request_init["cached_content"][field][i][subfield] + else: + del request_init["cached_content"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_cached_content(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_cached_content_rest_interceptors(null_interceptor): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.CacheServiceRestInterceptor(), + ) + client = CacheServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.CacheServiceRestInterceptor, "post_create_cached_content" + ) as post, mock.patch.object( + transports.CacheServiceRestInterceptor, "pre_create_cached_content" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = cache_service.CreateCachedContentRequest.pb( + cache_service.CreateCachedContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = gag_cached_content.CachedContent.to_json( + gag_cached_content.CachedContent() + ) + req.return_value.content = return_value + + request = cache_service.CreateCachedContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = gag_cached_content.CachedContent() + + client.create_cached_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_cached_content_rest_bad_request( + request_type=cache_service.GetCachedContentRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "cachedContents/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_cached_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.GetCachedContentRequest, + dict, + ], +) +def test_get_cached_content_rest_call_success(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "cachedContents/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_cached_content(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_cached_content_rest_interceptors(null_interceptor): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.CacheServiceRestInterceptor(), + ) + client = CacheServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.CacheServiceRestInterceptor, "post_get_cached_content" + ) as post, mock.patch.object( + transports.CacheServiceRestInterceptor, "pre_get_cached_content" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = cache_service.GetCachedContentRequest.pb( + cache_service.GetCachedContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = cached_content.CachedContent.to_json( + cached_content.CachedContent() + ) + req.return_value.content = return_value + + request = cache_service.GetCachedContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = cached_content.CachedContent() + + client.get_cached_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_cached_content_rest_bad_request( + request_type=cache_service.UpdateCachedContentRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"cached_content": {"name": "cachedContents/sample1"}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.update_cached_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.UpdateCachedContentRequest, + dict, + ], +) +def test_update_cached_content_rest_call_success(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"cached_content": {"name": "cachedContents/sample1"}} + request_init["cached_content"] = { + "expire_time": {"seconds": 751, "nanos": 543}, + "ttl": {"seconds": 751, "nanos": 543}, + "name": "cachedContents/sample1", + "display_name": "display_name_value", + "model": "model_value", + "system_instruction": { + "parts": [ + { + "text": "text_value", + "inline_data": { + "mime_type": "mime_type_value", + "data": b"data_blob", + }, + "function_call": { + "id": "id_value", + "name": "name_value", + "args": {"fields": {}}, + }, + "function_response": { + "id": "id_value", + "name": "name_value", + "response": {}, + }, + "file_data": { + "mime_type": "mime_type_value", + "file_uri": "file_uri_value", + }, + "executable_code": {"language": 1, "code": "code_value"}, + "code_execution_result": {"outcome": 1, "output": "output_value"}, + } + ], + "role": "role_value", + }, + "contents": {}, + "tools": [ + { + "function_declarations": [ + { + "name": "name_value", + "description": "description_value", + "parameters": { + "type_": 1, + "format_": "format__value", + "description": "description_value", + "nullable": True, + "enum": ["enum_value1", "enum_value2"], + "items": {}, + "max_items": 967, + "min_items": 965, + "properties": {}, + "required": ["required_value1", "required_value2"], + }, + "response": {}, + } + ], + "google_search_retrieval": { + "dynamic_retrieval_config": {"mode": 1, "dynamic_threshold": 0.1809} + }, + "code_execution": {}, + "google_search": {}, + } + ], + "tool_config": { + "function_calling_config": { + "mode": 1, + "allowed_function_names": [ + "allowed_function_names_value1", + "allowed_function_names_value2", + ], + } + }, + "create_time": {}, + "update_time": {}, + "usage_metadata": {"total_token_count": 1836}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = cache_service.UpdateCachedContentRequest.meta.fields["cached_content"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["cached_content"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["cached_content"][field])): + del request_init["cached_content"][field][i][subfield] + else: + del request_init["cached_content"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_cached_content.CachedContent( + name="name_value", + display_name="display_name_value", + model="model_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_cached_content.CachedContent.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.update_cached_content(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_cached_content.CachedContent) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.model == "model_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_cached_content_rest_interceptors(null_interceptor): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.CacheServiceRestInterceptor(), + ) + client = CacheServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.CacheServiceRestInterceptor, "post_update_cached_content" + ) as post, mock.patch.object( + transports.CacheServiceRestInterceptor, "pre_update_cached_content" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = cache_service.UpdateCachedContentRequest.pb( + cache_service.UpdateCachedContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = gag_cached_content.CachedContent.to_json( + gag_cached_content.CachedContent() + ) + req.return_value.content = return_value + + request = cache_service.UpdateCachedContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = gag_cached_content.CachedContent() + + client.update_cached_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_cached_content_rest_bad_request( + request_type=cache_service.DeleteCachedContentRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "cachedContents/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_cached_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + cache_service.DeleteCachedContentRequest, + dict, + ], +) +def test_delete_cached_content_rest_call_success(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "cachedContents/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_cached_content(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_cached_content_rest_interceptors(null_interceptor): + transport = transports.CacheServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.CacheServiceRestInterceptor(), + ) + client = CacheServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.CacheServiceRestInterceptor, "pre_delete_cached_content" + ) as pre: + pre.assert_not_called() + pb_message = cache_service.DeleteCachedContentRequest.pb( + cache_service.DeleteCachedContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = cache_service.DeleteCachedContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_cached_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_cached_contents_empty_call_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.list_cached_contents), "__call__" + ) as call: + client.list_cached_contents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.ListCachedContentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_cached_content_empty_call_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_cached_content), "__call__" + ) as call: + client.create_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.CreateCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_cached_content_empty_call_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.get_cached_content), "__call__" + ) as call: + client.get_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.GetCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_cached_content_empty_call_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_cached_content), "__call__" + ) as call: + client.update_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.UpdateCachedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_cached_content_empty_call_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_cached_content), "__call__" + ) as call: + client.delete_cached_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = cache_service.DeleteCachedContentRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.CacheServiceGrpcTransport, + ) + + +def test_cache_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.CacheServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_cache_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.cache_service.transports.CacheServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.CacheServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "list_cached_contents", + "create_cached_content", + "get_cached_content", + "update_cached_content", + "delete_cached_content", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_cache_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.cache_service.transports.CacheServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.CacheServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_cache_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.cache_service.transports.CacheServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.CacheServiceTransport() + adc.assert_called_once() + + +def test_cache_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + CacheServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.CacheServiceGrpcTransport, + transports.CacheServiceGrpcAsyncIOTransport, + ], +) +def test_cache_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.CacheServiceGrpcTransport, + transports.CacheServiceGrpcAsyncIOTransport, + transports.CacheServiceRestTransport, + ], +) +def test_cache_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.CacheServiceGrpcTransport, grpc_helpers), + (transports.CacheServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_cache_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.CacheServiceGrpcTransport, transports.CacheServiceGrpcAsyncIOTransport], +) +def test_cache_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_cache_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.CacheServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_cache_service_host_no_port(transport_name): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_cache_service_host_with_port(transport_name): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_cache_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = CacheServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = CacheServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.list_cached_contents._session + session2 = client2.transport.list_cached_contents._session + assert session1 != session2 + session1 = client1.transport.create_cached_content._session + session2 = client2.transport.create_cached_content._session + assert session1 != session2 + session1 = client1.transport.get_cached_content._session + session2 = client2.transport.get_cached_content._session + assert session1 != session2 + session1 = client1.transport.update_cached_content._session + session2 = client2.transport.update_cached_content._session + assert session1 != session2 + session1 = client1.transport.delete_cached_content._session + session2 = client2.transport.delete_cached_content._session + assert session1 != session2 + + +def test_cache_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.CacheServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_cache_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.CacheServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.CacheServiceGrpcTransport, transports.CacheServiceGrpcAsyncIOTransport], +) +def test_cache_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.CacheServiceGrpcTransport, transports.CacheServiceGrpcAsyncIOTransport], +) +def test_cache_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_cached_content_path(): + id = "squid" + expected = "cachedContents/{id}".format( + id=id, + ) + actual = CacheServiceClient.cached_content_path(id) + assert expected == actual + + +def test_parse_cached_content_path(): + expected = { + "id": "clam", + } + path = CacheServiceClient.cached_content_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_cached_content_path(path) + assert expected == actual + + +def test_model_path(): + model = "whelk" + expected = "models/{model}".format( + model=model, + ) + actual = CacheServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "octopus", + } + path = CacheServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "oyster" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = CacheServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = CacheServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "cuttlefish" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = CacheServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = CacheServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = CacheServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = CacheServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + expected = "projects/{project}".format( + project=project, + ) + actual = CacheServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = CacheServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = CacheServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = CacheServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = CacheServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.CacheServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.CacheServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = CacheServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = CacheServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = CacheServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (CacheServiceClient, transports.CacheServiceGrpcTransport), + (CacheServiceAsyncClient, transports.CacheServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_discuss_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_discuss_service.py new file mode 100644 index 000000000000..d6f690755511 --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_discuss_service.py @@ -0,0 +1,3756 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account + +from google.ai.generativelanguage_v1alpha.services.discuss_service import ( + DiscussServiceAsyncClient, + DiscussServiceClient, + transports, +) +from google.ai.generativelanguage_v1alpha.types import citation, discuss_service, safety + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert DiscussServiceClient._get_default_mtls_endpoint(None) is None + assert ( + DiscussServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + DiscussServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + ) + + +def test__read_environment_variables(): + assert DiscussServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert DiscussServiceClient._read_environment_variables() == ( + True, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert DiscussServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + DiscussServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert DiscussServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert DiscussServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert DiscussServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + DiscussServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert DiscussServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert DiscussServiceClient._get_client_cert_source(None, False) is None + assert ( + DiscussServiceClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + DiscussServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + DiscussServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + DiscussServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + DiscussServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = DiscussServiceClient._DEFAULT_UNIVERSE + default_endpoint = DiscussServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = DiscussServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + DiscussServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + DiscussServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + DiscussServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + DiscussServiceClient._get_api_endpoint(None, None, default_universe, "always") + == DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + DiscussServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == DiscussServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + DiscussServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + DiscussServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + DiscussServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + DiscussServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + DiscussServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + DiscussServiceClient._get_universe_domain(None, None) + == DiscussServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + DiscussServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), + ], +) +def test_discuss_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.DiscussServiceGrpcTransport, "grpc"), + (transports.DiscussServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.DiscussServiceRestTransport, "rest"), + ], +) +def test_discuss_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (DiscussServiceClient, "grpc"), + (DiscussServiceAsyncClient, "grpc_asyncio"), + (DiscussServiceClient, "rest"), + ], +) +def test_discuss_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_discuss_service_client_get_transport_class(): + transport = DiscussServiceClient.get_transport_class() + available_transports = [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceRestTransport, + ] + assert transport in available_transports + + transport = DiscussServiceClient.get_transport_class("grpc") + assert transport == transports.DiscussServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + DiscussServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceAsyncClient), +) +def test_discuss_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(DiscussServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(DiscussServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "true"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc", "false"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "true"), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + DiscussServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_discuss_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [DiscussServiceClient, DiscussServiceAsyncClient] +) +@mock.patch.object( + DiscussServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(DiscussServiceAsyncClient), +) +def test_discuss_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [DiscussServiceClient, DiscussServiceAsyncClient] +) +@mock.patch.object( + DiscussServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceClient), +) +@mock.patch.object( + DiscussServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(DiscussServiceAsyncClient), +) +def test_discuss_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = DiscussServiceClient._DEFAULT_UNIVERSE + default_endpoint = DiscussServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = DiscussServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport, "grpc"), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest"), + ], +) +def test_discuss_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + DiscussServiceClient, + transports.DiscussServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (DiscussServiceClient, transports.DiscussServiceRestTransport, "rest", None), + ], +) +def test_discuss_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_discuss_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.discuss_service.transports.DiscussServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = DiscussServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + DiscussServiceClient, + transports.DiscussServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + DiscussServiceAsyncClient, + transports.DiscussServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_discuss_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.GenerateMessageRequest, + dict, + ], +) +def test_generate_message(request_type, transport: str = "grpc"): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse() + response = client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = discuss_service.GenerateMessageRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +def test_generate_message_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = discuss_service.GenerateMessageRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.generate_message(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.GenerateMessageRequest( + model="model_value", + ) + + +def test_generate_message_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_message in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_message + ] = mock_rpc + request = {} + client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_message(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_message_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.generate_message + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.generate_message + ] = mock_rpc + + request = {} + await client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.generate_message(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_message_async( + transport: str = "grpc_asyncio", request_type=discuss_service.GenerateMessageRequest +): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) + response = await client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = discuss_service.GenerateMessageRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +@pytest.mark.asyncio +async def test_generate_message_async_from_dict(): + await test_generate_message_async(request_type=dict) + + +def test_generate_message_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.GenerateMessageRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + call.return_value = discuss_service.GenerateMessageResponse() + client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_generate_message_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.GenerateMessageRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) + await client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_generate_message_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_message( + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context="context_value") + assert arg == mock_val + assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) + arg = args[0].candidate_count + mock_val = 1573 + assert arg == mock_val + assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) + arg = args[0].top_k + mock_val = 541 + assert arg == mock_val + + +def test_generate_message_flattened_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_message( + discuss_service.GenerateMessageRequest(), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + +@pytest.mark.asyncio +async def test_generate_message_flattened_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.GenerateMessageResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.generate_message( + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context="context_value") + assert arg == mock_val + assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) + arg = args[0].candidate_count + mock_val = 1573 + assert arg == mock_val + assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) + arg = args[0].top_k + mock_val = 541 + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_generate_message_flattened_error_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.generate_message( + discuss_service.GenerateMessageRequest(), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.CountMessageTokensRequest, + dict, + ], +) +def test_count_message_tokens(request_type, transport: str = "grpc"): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + response = client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = discuss_service.CountMessageTokensRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +def test_count_message_tokens_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = discuss_service.CountMessageTokensRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.count_message_tokens(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == discuss_service.CountMessageTokensRequest( + model="model_value", + ) + + +def test_count_message_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.count_message_tokens in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.count_message_tokens + ] = mock_rpc + request = {} + client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_message_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_count_message_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.count_message_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.count_message_tokens + ] = mock_rpc + + request = {} + await client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.count_message_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_count_message_tokens_async( + transport: str = "grpc_asyncio", + request_type=discuss_service.CountMessageTokensRequest, +): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + ) + response = await client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = discuss_service.CountMessageTokensRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +@pytest.mark.asyncio +async def test_count_message_tokens_async_from_dict(): + await test_count_message_tokens_async(request_type=dict) + + +def test_count_message_tokens_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.CountMessageTokensRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + call.return_value = discuss_service.CountMessageTokensResponse() + client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_count_message_tokens_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = discuss_service.CountMessageTokensRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse() + ) + await client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_count_message_tokens_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.CountMessageTokensResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.count_message_tokens( + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context="context_value") + assert arg == mock_val + + +def test_count_message_tokens_flattened_error(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + ) + + +@pytest.mark.asyncio +async def test_count_message_tokens_flattened_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = discuss_service.CountMessageTokensResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.count_message_tokens( + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = discuss_service.MessagePrompt(context="context_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_count_message_tokens_flattened_error_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + ) + + +def test_generate_message_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_message in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_message + ] = mock_rpc + + request = {} + client.generate_message(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_message(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_generate_message_rest_required_fields( + request_type=discuss_service.GenerateMessageRequest, +): + transport_class = transports.DiscussServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_message._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_message._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.generate_message(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_generate_message_rest_unset_required_fields(): + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.generate_message._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) + + +def test_generate_message_rest_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.generate_message(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:generateMessage" % client.transport._host, + args[1], + ) + + +def test_generate_message_rest_flattened_error(transport: str = "rest"): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_message( + discuss_service.GenerateMessageRequest(), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + temperature=0.1198, + candidate_count=1573, + top_p=0.546, + top_k=541, + ) + + +def test_count_message_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.count_message_tokens in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.count_message_tokens + ] = mock_rpc + + request = {} + client.count_message_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_message_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_count_message_tokens_rest_required_fields( + request_type=discuss_service.CountMessageTokensRequest, +): + transport_class = transports.DiscussServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_message_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_message_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.count_message_tokens(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_count_message_tokens_rest_unset_required_fields(): + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.count_message_tokens._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) + + +def test_count_message_tokens_rest_flattened(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.count_message_tokens(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:countMessageTokens" % client.transport._host, + args[1], + ) + + +def test_count_message_tokens_rest_flattened_error(transport: str = "rest"): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_message_tokens( + discuss_service.CountMessageTokensRequest(), + model="model_value", + prompt=discuss_service.MessagePrompt(context="context_value"), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = DiscussServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = DiscussServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.DiscussServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.DiscussServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + transports.DiscussServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = DiscussServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_message_empty_call_grpc(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + call.return_value = discuss_service.GenerateMessageResponse() + client.generate_message(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = discuss_service.GenerateMessageRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_count_message_tokens_empty_call_grpc(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + call.return_value = discuss_service.CountMessageTokensResponse() + client.count_message_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = discuss_service.CountMessageTokensRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = DiscussServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_generate_message_empty_call_grpc_asyncio(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.GenerateMessageResponse() + ) + await client.generate_message(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = discuss_service.GenerateMessageRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_count_message_tokens_empty_call_grpc_asyncio(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + ) + await client.count_message_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = discuss_service.CountMessageTokensRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = DiscussServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_generate_message_rest_bad_request( + request_type=discuss_service.GenerateMessageRequest, +): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.generate_message(request) + + +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.GenerateMessageRequest, + dict, + ], +) +def test_generate_message_rest_call_success(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.GenerateMessageResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = discuss_service.GenerateMessageResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.generate_message(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.GenerateMessageResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_generate_message_rest_interceptors(null_interceptor): + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DiscussServiceRestInterceptor(), + ) + client = DiscussServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DiscussServiceRestInterceptor, "post_generate_message" + ) as post, mock.patch.object( + transports.DiscussServiceRestInterceptor, "pre_generate_message" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = discuss_service.GenerateMessageRequest.pb( + discuss_service.GenerateMessageRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = discuss_service.GenerateMessageResponse.to_json( + discuss_service.GenerateMessageResponse() + ) + req.return_value.content = return_value + + request = discuss_service.GenerateMessageRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = discuss_service.GenerateMessageResponse() + + client.generate_message( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_count_message_tokens_rest_bad_request( + request_type=discuss_service.CountMessageTokensRequest, +): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.count_message_tokens(request) + + +@pytest.mark.parametrize( + "request_type", + [ + discuss_service.CountMessageTokensRequest, + dict, + ], +) +def test_count_message_tokens_rest_call_success(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = discuss_service.CountMessageTokensResponse( + token_count=1193, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = discuss_service.CountMessageTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.count_message_tokens(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, discuss_service.CountMessageTokensResponse) + assert response.token_count == 1193 + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_count_message_tokens_rest_interceptors(null_interceptor): + transport = transports.DiscussServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.DiscussServiceRestInterceptor(), + ) + client = DiscussServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.DiscussServiceRestInterceptor, "post_count_message_tokens" + ) as post, mock.patch.object( + transports.DiscussServiceRestInterceptor, "pre_count_message_tokens" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = discuss_service.CountMessageTokensRequest.pb( + discuss_service.CountMessageTokensRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = discuss_service.CountMessageTokensResponse.to_json( + discuss_service.CountMessageTokensResponse() + ) + req.return_value.content = return_value + + request = discuss_service.CountMessageTokensRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = discuss_service.CountMessageTokensResponse() + + client.count_message_tokens( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_message_empty_call_rest(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_message), "__call__") as call: + client.generate_message(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = discuss_service.GenerateMessageRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_count_message_tokens_empty_call_rest(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.count_message_tokens), "__call__" + ) as call: + client.count_message_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = discuss_service.CountMessageTokensRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.DiscussServiceGrpcTransport, + ) + + +def test_discuss_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.DiscussServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_discuss_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.discuss_service.transports.DiscussServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.DiscussServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "generate_message", + "count_message_tokens", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_discuss_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.DiscussServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_discuss_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.discuss_service.transports.DiscussServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.DiscussServiceTransport() + adc.assert_called_once() + + +def test_discuss_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + DiscussServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + transports.DiscussServiceRestTransport, + ], +) +def test_discuss_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.DiscussServiceGrpcTransport, grpc_helpers), + (transports.DiscussServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_discuss_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_discuss_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.DiscussServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_discuss_service_host_no_port(transport_name): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_discuss_service_host_with_port(transport_name): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_discuss_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = DiscussServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = DiscussServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.generate_message._session + session2 = client2.transport.generate_message._session + assert session1 != session2 + session1 = client1.transport.count_message_tokens._session + session2 = client2.transport.count_message_tokens._session + assert session1 != session2 + + +def test_discuss_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.DiscussServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_discuss_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.DiscussServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.DiscussServiceGrpcTransport, + transports.DiscussServiceGrpcAsyncIOTransport, + ], +) +def test_discuss_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format( + model=model, + ) + actual = DiscussServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = DiscussServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = DiscussServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = DiscussServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = DiscussServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = DiscussServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = DiscussServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = DiscussServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format( + project=project, + ) + actual = DiscussServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = DiscussServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = DiscussServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = DiscussServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = DiscussServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.DiscussServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.DiscussServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = DiscussServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = DiscussServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = DiscussServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (DiscussServiceClient, transports.DiscussServiceGrpcTransport), + (DiscussServiceAsyncClient, transports.DiscussServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_file_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_file_service.py new file mode 100644 index 000000000000..77b6b85a9506 --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_file_service.py @@ -0,0 +1,4644 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import any_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore +from google.rpc import status_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.file_service import ( + FileServiceAsyncClient, + FileServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1alpha.types import file, file_service + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert FileServiceClient._get_default_mtls_endpoint(None) is None + assert ( + FileServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + FileServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + FileServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + FileServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert FileServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +def test__read_environment_variables(): + assert FileServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert FileServiceClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert FileServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + FileServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert FileServiceClient._read_environment_variables() == (False, "never", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert FileServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert FileServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + FileServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert FileServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert FileServiceClient._get_client_cert_source(None, False) is None + assert ( + FileServiceClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + FileServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + FileServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + FileServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + FileServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceClient), +) +@mock.patch.object( + FileServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = FileServiceClient._DEFAULT_UNIVERSE + default_endpoint = FileServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = FileServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + FileServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + FileServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == FileServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + FileServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + FileServiceClient._get_api_endpoint(None, None, default_universe, "always") + == FileServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + FileServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == FileServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + FileServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + FileServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + FileServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + FileServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + FileServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + FileServiceClient._get_universe_domain(None, None) + == FileServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + FileServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (FileServiceClient, "grpc"), + (FileServiceAsyncClient, "grpc_asyncio"), + (FileServiceClient, "rest"), + ], +) +def test_file_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.FileServiceGrpcTransport, "grpc"), + (transports.FileServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.FileServiceRestTransport, "rest"), + ], +) +def test_file_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (FileServiceClient, "grpc"), + (FileServiceAsyncClient, "grpc_asyncio"), + (FileServiceClient, "rest"), + ], +) +def test_file_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_file_service_client_get_transport_class(): + transport = FileServiceClient.get_transport_class() + available_transports = [ + transports.FileServiceGrpcTransport, + transports.FileServiceRestTransport, + ] + assert transport in available_transports + + transport = FileServiceClient.get_transport_class("grpc") + assert transport == transports.FileServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (FileServiceClient, transports.FileServiceGrpcTransport, "grpc"), + ( + FileServiceAsyncClient, + transports.FileServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (FileServiceClient, transports.FileServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + FileServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceClient), +) +@mock.patch.object( + FileServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceAsyncClient), +) +def test_file_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(FileServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(FileServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (FileServiceClient, transports.FileServiceGrpcTransport, "grpc", "true"), + ( + FileServiceAsyncClient, + transports.FileServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (FileServiceClient, transports.FileServiceGrpcTransport, "grpc", "false"), + ( + FileServiceAsyncClient, + transports.FileServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (FileServiceClient, transports.FileServiceRestTransport, "rest", "true"), + (FileServiceClient, transports.FileServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + FileServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceClient), +) +@mock.patch.object( + FileServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_file_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [FileServiceClient, FileServiceAsyncClient]) +@mock.patch.object( + FileServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(FileServiceClient) +) +@mock.patch.object( + FileServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(FileServiceAsyncClient), +) +def test_file_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize("client_class", [FileServiceClient, FileServiceAsyncClient]) +@mock.patch.object( + FileServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceClient), +) +@mock.patch.object( + FileServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(FileServiceAsyncClient), +) +def test_file_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = FileServiceClient._DEFAULT_UNIVERSE + default_endpoint = FileServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = FileServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (FileServiceClient, transports.FileServiceGrpcTransport, "grpc"), + ( + FileServiceAsyncClient, + transports.FileServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (FileServiceClient, transports.FileServiceRestTransport, "rest"), + ], +) +def test_file_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + (FileServiceClient, transports.FileServiceGrpcTransport, "grpc", grpc_helpers), + ( + FileServiceAsyncClient, + transports.FileServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (FileServiceClient, transports.FileServiceRestTransport, "rest", None), + ], +) +def test_file_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_file_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.file_service.transports.FileServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = FileServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + (FileServiceClient, transports.FileServiceGrpcTransport, "grpc", grpc_helpers), + ( + FileServiceAsyncClient, + transports.FileServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_file_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.CreateFileRequest, + dict, + ], +) +def test_create_file(request_type, transport: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = file_service.CreateFileResponse() + response = client.create_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = file_service.CreateFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, file_service.CreateFileResponse) + + +def test_create_file_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = file_service.CreateFileRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_file(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == file_service.CreateFileRequest() + + +def test_create_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_file] = mock_rpc + request = {} + client.create_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_file_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_file + ] = mock_rpc + + request = {} + await client.create_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.create_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_file_async( + transport: str = "grpc_asyncio", request_type=file_service.CreateFileRequest +): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + file_service.CreateFileResponse() + ) + response = await client.create_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = file_service.CreateFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, file_service.CreateFileResponse) + + +@pytest.mark.asyncio +async def test_create_file_async_from_dict(): + await test_create_file_async(request_type=dict) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.ListFilesRequest, + dict, + ], +) +def test_list_files(request_type, transport: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = file_service.ListFilesResponse( + next_page_token="next_page_token_value", + ) + response = client.list_files(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = file_service.ListFilesRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListFilesPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_files_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = file_service.ListFilesRequest( + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_files(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == file_service.ListFilesRequest( + page_token="page_token_value", + ) + + +def test_list_files_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_files in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_files] = mock_rpc + request = {} + client.list_files(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_files_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_files + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_files + ] = mock_rpc + + request = {} + await client.list_files(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_files_async( + transport: str = "grpc_asyncio", request_type=file_service.ListFilesRequest +): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + file_service.ListFilesResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_files(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = file_service.ListFilesRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListFilesAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_files_async_from_dict(): + await test_list_files_async(request_type=dict) + + +def test_list_files_pager(transport_name: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + file.File(), + ], + next_page_token="abc", + ), + file_service.ListFilesResponse( + files=[], + next_page_token="def", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + ], + next_page_token="ghi", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + pager = client.list_files(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, file.File) for i in results) + + +def test_list_files_pages(transport_name: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + file.File(), + ], + next_page_token="abc", + ), + file_service.ListFilesResponse( + files=[], + next_page_token="def", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + ], + next_page_token="ghi", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + ], + ), + RuntimeError, + ) + pages = list(client.list_files(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_files_async_pager(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_files), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + file.File(), + ], + next_page_token="abc", + ), + file_service.ListFilesResponse( + files=[], + next_page_token="def", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + ], + next_page_token="ghi", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_files( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, file.File) for i in responses) + + +@pytest.mark.asyncio +async def test_list_files_async_pages(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_files), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + file.File(), + ], + next_page_token="abc", + ), + file_service.ListFilesResponse( + files=[], + next_page_token="def", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + ], + next_page_token="ghi", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_files(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.GetFileRequest, + dict, + ], +) +def test_get_file(request_type, transport: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = file.File( + name="name_value", + display_name="display_name_value", + mime_type="mime_type_value", + size_bytes=1089, + sha256_hash=b"sha256_hash_blob", + uri="uri_value", + state=file.File.State.PROCESSING, + ) + response = client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = file_service.GetFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, file.File) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.mime_type == "mime_type_value" + assert response.size_bytes == 1089 + assert response.sha256_hash == b"sha256_hash_blob" + assert response.uri == "uri_value" + assert response.state == file.File.State.PROCESSING + + +def test_get_file_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = file_service.GetFileRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_file(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == file_service.GetFileRequest( + name="name_value", + ) + + +def test_get_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_file] = mock_rpc + request = {} + client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_file_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_file + ] = mock_rpc + + request = {} + await client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_file_async( + transport: str = "grpc_asyncio", request_type=file_service.GetFileRequest +): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + file.File( + name="name_value", + display_name="display_name_value", + mime_type="mime_type_value", + size_bytes=1089, + sha256_hash=b"sha256_hash_blob", + uri="uri_value", + state=file.File.State.PROCESSING, + ) + ) + response = await client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = file_service.GetFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, file.File) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.mime_type == "mime_type_value" + assert response.size_bytes == 1089 + assert response.sha256_hash == b"sha256_hash_blob" + assert response.uri == "uri_value" + assert response.state == file.File.State.PROCESSING + + +@pytest.mark.asyncio +async def test_get_file_async_from_dict(): + await test_get_file_async(request_type=dict) + + +def test_get_file_field_headers(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = file_service.GetFileRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + call.return_value = file.File() + client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_file_field_headers_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = file_service.GetFileRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(file.File()) + await client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_file_flattened(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = file.File() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_file( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_file_flattened_error(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_file( + file_service.GetFileRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_file_flattened_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = file.File() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(file.File()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_file( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_file_flattened_error_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_file( + file_service.GetFileRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.DeleteFileRequest, + dict, + ], +) +def test_delete_file(request_type, transport: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = file_service.DeleteFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_file_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = file_service.DeleteFileRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_file(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == file_service.DeleteFileRequest( + name="name_value", + ) + + +def test_delete_file_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_file] = mock_rpc + request = {} + client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_file_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_file + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_file + ] = mock_rpc + + request = {} + await client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_file_async( + transport: str = "grpc_asyncio", request_type=file_service.DeleteFileRequest +): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = file_service.DeleteFileRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_file_async_from_dict(): + await test_delete_file_async(request_type=dict) + + +def test_delete_file_field_headers(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = file_service.DeleteFileRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + call.return_value = None + client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_file_field_headers_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = file_service.DeleteFileRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_file_flattened(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_file( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_file_flattened_error(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_file( + file_service.DeleteFileRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_file_flattened_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_file( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_file_flattened_error_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_file( + file_service.DeleteFileRequest(), + name="name_value", + ) + + +def test_create_file_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_file] = mock_rpc + + request = {} + client.create_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_files_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_files in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_files] = mock_rpc + + request = {} + client.list_files(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_files(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_files_rest_pager(transport: str = "rest"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + file.File(), + ], + next_page_token="abc", + ), + file_service.ListFilesResponse( + files=[], + next_page_token="def", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + ], + next_page_token="ghi", + ), + file_service.ListFilesResponse( + files=[ + file.File(), + file.File(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(file_service.ListFilesResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_files(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, file.File) for i in results) + + pages = list(client.list_files(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_file_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_file] = mock_rpc + + request = {} + client.get_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_file_rest_required_fields(request_type=file_service.GetFileRequest): + transport_class = transports.FileServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_file._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_file._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = file.File() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = file.File.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_file(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_file_rest_unset_required_fields(): + transport = transports.FileServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_file._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_file_rest_flattened(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = file.File() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "files/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = file.File.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_file(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=files/*}" % client.transport._host, args[1] + ) + + +def test_get_file_rest_flattened_error(transport: str = "rest"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_file( + file_service.GetFileRequest(), + name="name_value", + ) + + +def test_delete_file_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_file in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_file] = mock_rpc + + request = {} + client.delete_file(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_file(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_file_rest_required_fields(request_type=file_service.DeleteFileRequest): + transport_class = transports.FileServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_file._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_file._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_file(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_file_rest_unset_required_fields(): + transport = transports.FileServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_file._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_delete_file_rest_flattened(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "files/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_file(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=files/*}" % client.transport._host, args[1] + ) + + +def test_delete_file_rest_flattened_error(transport: str = "rest"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_file( + file_service.DeleteFileRequest(), + name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.FileServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.FileServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FileServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.FileServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FileServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = FileServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.FileServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = FileServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.FileServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = FileServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.FileServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.FileServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.FileServiceGrpcTransport, + transports.FileServiceGrpcAsyncIOTransport, + transports.FileServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = FileServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_file_empty_call_grpc(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_file), "__call__") as call: + call.return_value = file_service.CreateFileResponse() + client.create_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.CreateFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_files_empty_call_grpc(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + call.return_value = file_service.ListFilesResponse() + client.list_files(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.ListFilesRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_file_empty_call_grpc(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + call.return_value = file.File() + client.get_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.GetFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_file_empty_call_grpc(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + call.return_value = None + client.delete_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.DeleteFileRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = FileServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_file_empty_call_grpc_asyncio(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + file_service.CreateFileResponse() + ) + await client.create_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.CreateFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_files_empty_call_grpc_asyncio(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + file_service.ListFilesResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_files(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.ListFilesRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_file_empty_call_grpc_asyncio(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + file.File( + name="name_value", + display_name="display_name_value", + mime_type="mime_type_value", + size_bytes=1089, + sha256_hash=b"sha256_hash_blob", + uri="uri_value", + state=file.File.State.PROCESSING, + ) + ) + await client.get_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.GetFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_file_empty_call_grpc_asyncio(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.DeleteFileRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = FileServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_create_file_rest_bad_request(request_type=file_service.CreateFileRequest): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_file(request) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.CreateFileRequest, + dict, + ], +) +def test_create_file_rest_call_success(request_type): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = file_service.CreateFileResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = file_service.CreateFileResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_file(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, file_service.CreateFileResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_file_rest_interceptors(null_interceptor): + transport = transports.FileServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.FileServiceRestInterceptor(), + ) + client = FileServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FileServiceRestInterceptor, "post_create_file" + ) as post, mock.patch.object( + transports.FileServiceRestInterceptor, "pre_create_file" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = file_service.CreateFileRequest.pb(file_service.CreateFileRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = file_service.CreateFileResponse.to_json( + file_service.CreateFileResponse() + ) + req.return_value.content = return_value + + request = file_service.CreateFileRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = file_service.CreateFileResponse() + + client.create_file( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_files_rest_bad_request(request_type=file_service.ListFilesRequest): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_files(request) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.ListFilesRequest, + dict, + ], +) +def test_list_files_rest_call_success(request_type): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = file_service.ListFilesResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = file_service.ListFilesResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_files(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListFilesPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_files_rest_interceptors(null_interceptor): + transport = transports.FileServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.FileServiceRestInterceptor(), + ) + client = FileServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FileServiceRestInterceptor, "post_list_files" + ) as post, mock.patch.object( + transports.FileServiceRestInterceptor, "pre_list_files" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = file_service.ListFilesRequest.pb(file_service.ListFilesRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = file_service.ListFilesResponse.to_json( + file_service.ListFilesResponse() + ) + req.return_value.content = return_value + + request = file_service.ListFilesRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = file_service.ListFilesResponse() + + client.list_files( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_file_rest_bad_request(request_type=file_service.GetFileRequest): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "files/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_file(request) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.GetFileRequest, + dict, + ], +) +def test_get_file_rest_call_success(request_type): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "files/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = file.File( + name="name_value", + display_name="display_name_value", + mime_type="mime_type_value", + size_bytes=1089, + sha256_hash=b"sha256_hash_blob", + uri="uri_value", + state=file.File.State.PROCESSING, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = file.File.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_file(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, file.File) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.mime_type == "mime_type_value" + assert response.size_bytes == 1089 + assert response.sha256_hash == b"sha256_hash_blob" + assert response.uri == "uri_value" + assert response.state == file.File.State.PROCESSING + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_file_rest_interceptors(null_interceptor): + transport = transports.FileServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.FileServiceRestInterceptor(), + ) + client = FileServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FileServiceRestInterceptor, "post_get_file" + ) as post, mock.patch.object( + transports.FileServiceRestInterceptor, "pre_get_file" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = file_service.GetFileRequest.pb(file_service.GetFileRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = file.File.to_json(file.File()) + req.return_value.content = return_value + + request = file_service.GetFileRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = file.File() + + client.get_file( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_file_rest_bad_request(request_type=file_service.DeleteFileRequest): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "files/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_file(request) + + +@pytest.mark.parametrize( + "request_type", + [ + file_service.DeleteFileRequest, + dict, + ], +) +def test_delete_file_rest_call_success(request_type): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "files/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_file(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_file_rest_interceptors(null_interceptor): + transport = transports.FileServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.FileServiceRestInterceptor(), + ) + client = FileServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.FileServiceRestInterceptor, "pre_delete_file" + ) as pre: + pre.assert_not_called() + pb_message = file_service.DeleteFileRequest.pb(file_service.DeleteFileRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = file_service.DeleteFileRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_file( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_file_empty_call_rest(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_file), "__call__") as call: + client.create_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.CreateFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_files_empty_call_rest(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_files), "__call__") as call: + client.list_files(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.ListFilesRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_file_empty_call_rest(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_file), "__call__") as call: + client.get_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.GetFileRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_file_empty_call_rest(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_file), "__call__") as call: + client.delete_file(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = file_service.DeleteFileRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.FileServiceGrpcTransport, + ) + + +def test_file_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.FileServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_file_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.file_service.transports.FileServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.FileServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_file", + "list_files", + "get_file", + "delete_file", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_file_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.file_service.transports.FileServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.FileServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_file_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.file_service.transports.FileServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.FileServiceTransport() + adc.assert_called_once() + + +def test_file_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + FileServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.FileServiceGrpcTransport, + transports.FileServiceGrpcAsyncIOTransport, + ], +) +def test_file_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.FileServiceGrpcTransport, + transports.FileServiceGrpcAsyncIOTransport, + transports.FileServiceRestTransport, + ], +) +def test_file_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.FileServiceGrpcTransport, grpc_helpers), + (transports.FileServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_file_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.FileServiceGrpcTransport, transports.FileServiceGrpcAsyncIOTransport], +) +def test_file_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_file_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.FileServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_file_service_host_no_port(transport_name): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_file_service_host_with_port(transport_name): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_file_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = FileServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = FileServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_file._session + session2 = client2.transport.create_file._session + assert session1 != session2 + session1 = client1.transport.list_files._session + session2 = client2.transport.list_files._session + assert session1 != session2 + session1 = client1.transport.get_file._session + session2 = client2.transport.get_file._session + assert session1 != session2 + session1 = client1.transport.delete_file._session + session2 = client2.transport.delete_file._session + assert session1 != session2 + + +def test_file_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.FileServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_file_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.FileServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.FileServiceGrpcTransport, transports.FileServiceGrpcAsyncIOTransport], +) +def test_file_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.FileServiceGrpcTransport, transports.FileServiceGrpcAsyncIOTransport], +) +def test_file_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_file_path(): + file = "squid" + expected = "files/{file}".format( + file=file, + ) + actual = FileServiceClient.file_path(file) + assert expected == actual + + +def test_parse_file_path(): + expected = { + "file": "clam", + } + path = FileServiceClient.file_path(**expected) + + # Check that the path construction is reversible. + actual = FileServiceClient.parse_file_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = FileServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = FileServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = FileServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = FileServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = FileServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = FileServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = FileServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = FileServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = FileServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format( + project=project, + ) + actual = FileServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = FileServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = FileServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = FileServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = FileServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = FileServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.FileServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.FileServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = FileServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = FileServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = FileServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (FileServiceClient, transports.FileServiceGrpcTransport), + (FileServiceAsyncClient, transports.FileServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_generative_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_generative_service.py new file mode 100644 index 000000000000..43f13f03055e --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_generative_service.py @@ -0,0 +1,6939 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import struct_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.generative_service import ( + GenerativeServiceAsyncClient, + GenerativeServiceClient, + transports, +) +from google.ai.generativelanguage_v1alpha.types import ( + generative_service, + retriever, + safety, +) +from google.ai.generativelanguage_v1alpha.types import content +from google.ai.generativelanguage_v1alpha.types import content as gag_content + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert GenerativeServiceClient._get_default_mtls_endpoint(None) is None + assert ( + GenerativeServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + GenerativeServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + GenerativeServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + GenerativeServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + GenerativeServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test__read_environment_variables(): + assert GenerativeServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert GenerativeServiceClient._read_environment_variables() == ( + True, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert GenerativeServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + GenerativeServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert GenerativeServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert GenerativeServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert GenerativeServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + GenerativeServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert GenerativeServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert GenerativeServiceClient._get_client_cert_source(None, False) is None + assert ( + GenerativeServiceClient._get_client_cert_source( + mock_provided_cert_source, False + ) + is None + ) + assert ( + GenerativeServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + GenerativeServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + GenerativeServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + GenerativeServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceClient), +) +@mock.patch.object( + GenerativeServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = GenerativeServiceClient._DEFAULT_UNIVERSE + default_endpoint = GenerativeServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = GenerativeServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + GenerativeServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + GenerativeServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == GenerativeServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + GenerativeServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + GenerativeServiceClient._get_api_endpoint( + None, None, default_universe, "always" + ) + == GenerativeServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + GenerativeServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == GenerativeServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + GenerativeServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + GenerativeServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + GenerativeServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + GenerativeServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + GenerativeServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + GenerativeServiceClient._get_universe_domain(None, None) + == GenerativeServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + GenerativeServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (GenerativeServiceClient, "grpc"), + (GenerativeServiceAsyncClient, "grpc_asyncio"), + (GenerativeServiceClient, "rest"), + ], +) +def test_generative_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.GenerativeServiceGrpcTransport, "grpc"), + (transports.GenerativeServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.GenerativeServiceRestTransport, "rest"), + ], +) +def test_generative_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (GenerativeServiceClient, "grpc"), + (GenerativeServiceAsyncClient, "grpc_asyncio"), + (GenerativeServiceClient, "rest"), + ], +) +def test_generative_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_generative_service_client_get_transport_class(): + transport = GenerativeServiceClient.get_transport_class() + available_transports = [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceRestTransport, + ] + assert transport in available_transports + + transport = GenerativeServiceClient.get_transport_class("grpc") + assert transport == transports.GenerativeServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (GenerativeServiceClient, transports.GenerativeServiceGrpcTransport, "grpc"), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (GenerativeServiceClient, transports.GenerativeServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + GenerativeServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceClient), +) +@mock.patch.object( + GenerativeServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceAsyncClient), +) +def test_generative_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(GenerativeServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(GenerativeServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + GenerativeServiceClient, + transports.GenerativeServiceGrpcTransport, + "grpc", + "true", + ), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + GenerativeServiceClient, + transports.GenerativeServiceGrpcTransport, + "grpc", + "false", + ), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ( + GenerativeServiceClient, + transports.GenerativeServiceRestTransport, + "rest", + "true", + ), + ( + GenerativeServiceClient, + transports.GenerativeServiceRestTransport, + "rest", + "false", + ), + ], +) +@mock.patch.object( + GenerativeServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceClient), +) +@mock.patch.object( + GenerativeServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_generative_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [GenerativeServiceClient, GenerativeServiceAsyncClient] +) +@mock.patch.object( + GenerativeServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(GenerativeServiceClient), +) +@mock.patch.object( + GenerativeServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(GenerativeServiceAsyncClient), +) +def test_generative_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [GenerativeServiceClient, GenerativeServiceAsyncClient] +) +@mock.patch.object( + GenerativeServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceClient), +) +@mock.patch.object( + GenerativeServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(GenerativeServiceAsyncClient), +) +def test_generative_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = GenerativeServiceClient._DEFAULT_UNIVERSE + default_endpoint = GenerativeServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = GenerativeServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (GenerativeServiceClient, transports.GenerativeServiceGrpcTransport, "grpc"), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (GenerativeServiceClient, transports.GenerativeServiceRestTransport, "rest"), + ], +) +def test_generative_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + GenerativeServiceClient, + transports.GenerativeServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ( + GenerativeServiceClient, + transports.GenerativeServiceRestTransport, + "rest", + None, + ), + ], +) +def test_generative_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_generative_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.generative_service.transports.GenerativeServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = GenerativeServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + GenerativeServiceClient, + transports.GenerativeServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_generative_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.GenerateContentRequest, + dict, + ], +) +def test_generate_content(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.GenerateContentResponse( + model_version="model_version_value", + ) + response = client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = generative_service.GenerateContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateContentResponse) + assert response.model_version == "model_version_value" + + +def test_generate_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = generative_service.GenerateContentRequest( + model="model_value", + cached_content="cached_content_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.generate_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == generative_service.GenerateContentRequest( + model="model_value", + cached_content="cached_content_value", + ) + + +def test_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_content + ] = mock_rpc + request = {} + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.generate_content + ] = mock_rpc + + request = {} + await client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_content_async( + transport: str = "grpc_asyncio", + request_type=generative_service.GenerateContentRequest, +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateContentResponse( + model_version="model_version_value", + ) + ) + response = await client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = generative_service.GenerateContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateContentResponse) + assert response.model_version == "model_version_value" + + +@pytest.mark.asyncio +async def test_generate_content_async_from_dict(): + await test_generate_content_async(request_type=dict) + + +def test_generate_content_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.GenerateContentRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value = generative_service.GenerateContentResponse() + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_generate_content_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.GenerateContentRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateContentResponse() + ) + await client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_generate_content_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.GenerateContentResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_content( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + + +def test_generate_content_flattened_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_content( + generative_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +@pytest.mark.asyncio +async def test_generate_content_flattened_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.GenerateContentResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateContentResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.generate_content( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_generate_content_flattened_error_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.generate_content( + generative_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.GenerateAnswerRequest, + dict, + ], +) +def test_generate_answer(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.GenerateAnswerResponse( + answerable_probability=0.234, + ) + response = client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = generative_service.GenerateAnswerRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateAnswerResponse) + assert math.isclose(response.answerable_probability, 0.234, rel_tol=1e-6) + + +def test_generate_answer_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = generative_service.GenerateAnswerRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.generate_answer(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == generative_service.GenerateAnswerRequest( + model="model_value", + ) + + +def test_generate_answer_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_answer in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.generate_answer] = mock_rpc + request = {} + client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_answer(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_answer_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.generate_answer + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.generate_answer + ] = mock_rpc + + request = {} + await client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.generate_answer(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_answer_async( + transport: str = "grpc_asyncio", + request_type=generative_service.GenerateAnswerRequest, +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateAnswerResponse( + answerable_probability=0.234, + ) + ) + response = await client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = generative_service.GenerateAnswerRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateAnswerResponse) + assert math.isclose(response.answerable_probability, 0.234, rel_tol=1e-6) + + +@pytest.mark.asyncio +async def test_generate_answer_async_from_dict(): + await test_generate_answer_async(request_type=dict) + + +def test_generate_answer_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.GenerateAnswerRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + call.return_value = generative_service.GenerateAnswerResponse() + client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_generate_answer_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.GenerateAnswerRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateAnswerResponse() + ) + await client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_generate_answer_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.GenerateAnswerResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_answer( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + safety_settings=[ + safety.SafetySetting( + category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY + ) + ], + answer_style=generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + arg = args[0].safety_settings + mock_val = [ + safety.SafetySetting(category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY) + ] + assert arg == mock_val + arg = args[0].answer_style + mock_val = generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE + assert arg == mock_val + + +def test_generate_answer_flattened_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_answer( + generative_service.GenerateAnswerRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + safety_settings=[ + safety.SafetySetting( + category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY + ) + ], + answer_style=generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + ) + + +@pytest.mark.asyncio +async def test_generate_answer_flattened_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.GenerateAnswerResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateAnswerResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.generate_answer( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + safety_settings=[ + safety.SafetySetting( + category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY + ) + ], + answer_style=generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + arg = args[0].safety_settings + mock_val = [ + safety.SafetySetting(category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY) + ] + assert arg == mock_val + arg = args[0].answer_style + mock_val = generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_generate_answer_flattened_error_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.generate_answer( + generative_service.GenerateAnswerRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + safety_settings=[ + safety.SafetySetting( + category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY + ) + ], + answer_style=generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.GenerateContentRequest, + dict, + ], +) +def test_stream_generate_content(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter([generative_service.GenerateContentResponse()]) + response = client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = generative_service.GenerateContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, generative_service.GenerateContentResponse) + + +def test_stream_generate_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = generative_service.GenerateContentRequest( + model="model_value", + cached_content="cached_content_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.stream_generate_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == generative_service.GenerateContentRequest( + model="model_value", + cached_content="cached_content_value", + ) + + +def test_stream_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_generate_content + ] = mock_rpc + request = {} + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_stream_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.stream_generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.stream_generate_content + ] = mock_rpc + + request = {} + await client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_stream_generate_content_async( + transport: str = "grpc_asyncio", + request_type=generative_service.GenerateContentRequest, +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[generative_service.GenerateContentResponse()] + ) + response = await client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = generative_service.GenerateContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, generative_service.GenerateContentResponse) + + +@pytest.mark.asyncio +async def test_stream_generate_content_async_from_dict(): + await test_stream_generate_content_async(request_type=dict) + + +def test_stream_generate_content_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.GenerateContentRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + call.return_value = iter([generative_service.GenerateContentResponse()]) + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_stream_generate_content_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.GenerateContentRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[generative_service.GenerateContentResponse()] + ) + await client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_stream_generate_content_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter([generative_service.GenerateContentResponse()]) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.stream_generate_content( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + + +def test_stream_generate_content_flattened_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.stream_generate_content( + generative_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +@pytest.mark.asyncio +async def test_stream_generate_content_flattened_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter([generative_service.GenerateContentResponse()]) + + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.stream_generate_content( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_stream_generate_content_flattened_error_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.stream_generate_content( + generative_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.EmbedContentRequest, + dict, + ], +) +def test_embed_content(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.EmbedContentResponse() + response = client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = generative_service.EmbedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.EmbedContentResponse) + + +def test_embed_content_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = generative_service.EmbedContentRequest( + model="model_value", + title="title_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.embed_content(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == generative_service.EmbedContentRequest( + model="model_value", + title="title_value", + ) + + +def test_embed_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.embed_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.embed_content] = mock_rpc + request = {} + client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.embed_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_embed_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.embed_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.embed_content + ] = mock_rpc + + request = {} + await client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.embed_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_embed_content_async( + transport: str = "grpc_asyncio", request_type=generative_service.EmbedContentRequest +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.EmbedContentResponse() + ) + response = await client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = generative_service.EmbedContentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.EmbedContentResponse) + + +@pytest.mark.asyncio +async def test_embed_content_async_from_dict(): + await test_embed_content_async(request_type=dict) + + +def test_embed_content_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.EmbedContentRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + call.return_value = generative_service.EmbedContentResponse() + client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_embed_content_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.EmbedContentRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.EmbedContentResponse() + ) + await client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_embed_content_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.EmbedContentResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.embed_content( + model="model_value", + content=gag_content.Content(parts=[gag_content.Part(text="text_value")]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].content + mock_val = gag_content.Content(parts=[gag_content.Part(text="text_value")]) + assert arg == mock_val + + +def test_embed_content_flattened_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.embed_content( + generative_service.EmbedContentRequest(), + model="model_value", + content=gag_content.Content(parts=[gag_content.Part(text="text_value")]), + ) + + +@pytest.mark.asyncio +async def test_embed_content_flattened_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.EmbedContentResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.EmbedContentResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.embed_content( + model="model_value", + content=gag_content.Content(parts=[gag_content.Part(text="text_value")]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].content + mock_val = gag_content.Content(parts=[gag_content.Part(text="text_value")]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_embed_content_flattened_error_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.embed_content( + generative_service.EmbedContentRequest(), + model="model_value", + content=gag_content.Content(parts=[gag_content.Part(text="text_value")]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.BatchEmbedContentsRequest, + dict, + ], +) +def test_batch_embed_contents(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.BatchEmbedContentsResponse() + response = client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = generative_service.BatchEmbedContentsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.BatchEmbedContentsResponse) + + +def test_batch_embed_contents_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = generative_service.BatchEmbedContentsRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.batch_embed_contents(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == generative_service.BatchEmbedContentsRequest( + model="model_value", + ) + + +def test_batch_embed_contents_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_embed_contents in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_embed_contents + ] = mock_rpc + request = {} + client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_embed_contents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_embed_contents_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_embed_contents + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_embed_contents + ] = mock_rpc + + request = {} + await client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.batch_embed_contents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_embed_contents_async( + transport: str = "grpc_asyncio", + request_type=generative_service.BatchEmbedContentsRequest, +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.BatchEmbedContentsResponse() + ) + response = await client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = generative_service.BatchEmbedContentsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.BatchEmbedContentsResponse) + + +@pytest.mark.asyncio +async def test_batch_embed_contents_async_from_dict(): + await test_batch_embed_contents_async(request_type=dict) + + +def test_batch_embed_contents_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.BatchEmbedContentsRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + call.return_value = generative_service.BatchEmbedContentsResponse() + client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_embed_contents_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.BatchEmbedContentsRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.BatchEmbedContentsResponse() + ) + await client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_batch_embed_contents_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.BatchEmbedContentsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_embed_contents( + model="model_value", + requests=[generative_service.EmbedContentRequest(model="model_value")], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].requests + mock_val = [generative_service.EmbedContentRequest(model="model_value")] + assert arg == mock_val + + +def test_batch_embed_contents_flattened_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_embed_contents( + generative_service.BatchEmbedContentsRequest(), + model="model_value", + requests=[generative_service.EmbedContentRequest(model="model_value")], + ) + + +@pytest.mark.asyncio +async def test_batch_embed_contents_flattened_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.BatchEmbedContentsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.BatchEmbedContentsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_embed_contents( + model="model_value", + requests=[generative_service.EmbedContentRequest(model="model_value")], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].requests + mock_val = [generative_service.EmbedContentRequest(model="model_value")] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_batch_embed_contents_flattened_error_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_embed_contents( + generative_service.BatchEmbedContentsRequest(), + model="model_value", + requests=[generative_service.EmbedContentRequest(model="model_value")], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.CountTokensRequest, + dict, + ], +) +def test_count_tokens(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.CountTokensResponse( + total_tokens=1303, + cached_content_token_count=2746, + ) + response = client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = generative_service.CountTokensRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.CountTokensResponse) + assert response.total_tokens == 1303 + assert response.cached_content_token_count == 2746 + + +def test_count_tokens_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = generative_service.CountTokensRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.count_tokens(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == generative_service.CountTokensRequest( + model="model_value", + ) + + +def test_count_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.count_tokens] = mock_rpc + request = {} + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_count_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.count_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.count_tokens + ] = mock_rpc + + request = {} + await client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_count_tokens_async( + transport: str = "grpc_asyncio", request_type=generative_service.CountTokensRequest +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.CountTokensResponse( + total_tokens=1303, + cached_content_token_count=2746, + ) + ) + response = await client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = generative_service.CountTokensRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.CountTokensResponse) + assert response.total_tokens == 1303 + assert response.cached_content_token_count == 2746 + + +@pytest.mark.asyncio +async def test_count_tokens_async_from_dict(): + await test_count_tokens_async(request_type=dict) + + +def test_count_tokens_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.CountTokensRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value = generative_service.CountTokensResponse() + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_count_tokens_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = generative_service.CountTokensRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.CountTokensResponse() + ) + await client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_count_tokens_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.CountTokensResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.count_tokens( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + + +def test_count_tokens_flattened_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_tokens( + generative_service.CountTokensRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +@pytest.mark.asyncio +async def test_count_tokens_flattened_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = generative_service.CountTokensResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.CountTokensResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.count_tokens( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].contents + mock_val = [content.Content(parts=[content.Part(text="text_value")])] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_count_tokens_flattened_error_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.count_tokens( + generative_service.CountTokensRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.BidiGenerateContentClientMessage, + dict, + ], +) +def test_bidi_generate_content(request_type, transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + requests = [request] + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.bidi_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = iter( + [generative_service.BidiGenerateContentServerMessage()] + ) + response = client.bidi_generate_content(iter(requests)) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert next(args[0]) == request + + # Establish that the response is the type that we expect. + for message in response: + assert isinstance(message, generative_service.BidiGenerateContentServerMessage) + + +def test_bidi_generate_content_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.bidi_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.bidi_generate_content + ] = mock_rpc + request = [{}] + client.bidi_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.bidi_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_bidi_generate_content_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.bidi_generate_content + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.bidi_generate_content + ] = mock_rpc + + request = [{}] + await client.bidi_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.bidi_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_bidi_generate_content_async( + transport: str = "grpc_asyncio", + request_type=generative_service.BidiGenerateContentClientMessage, +): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + requests = [request] + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.bidi_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.StreamStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[generative_service.BidiGenerateContentServerMessage()] + ) + response = await client.bidi_generate_content(iter(requests)) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert next(args[0]) == request + + # Establish that the response is the type that we expect. + message = await response.read() + assert isinstance(message, generative_service.BidiGenerateContentServerMessage) + + +@pytest.mark.asyncio +async def test_bidi_generate_content_async_from_dict(): + await test_bidi_generate_content_async(request_type=dict) + + +def test_generate_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.generate_content + ] = mock_rpc + + request = {} + client.generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_generate_content_rest_required_fields( + request_type=generative_service.GenerateContentRequest, +): + transport_class = transports.GenerativeServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateContentResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.GenerateContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.generate_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_generate_content_rest_unset_required_fields(): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.generate_content._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "contents", + ) + ) + ) + + +def test_generate_content_rest_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateContentResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = generative_service.GenerateContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.generate_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:generateContent" % client.transport._host, + args[1], + ) + + +def test_generate_content_rest_flattened_error(transport: str = "rest"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_content( + generative_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +def test_generate_answer_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_answer in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.generate_answer] = mock_rpc + + request = {} + client.generate_answer(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_answer(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_generate_answer_rest_required_fields( + request_type=generative_service.GenerateAnswerRequest, +): + transport_class = transports.GenerativeServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_answer._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_answer._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateAnswerResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.GenerateAnswerResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.generate_answer(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_generate_answer_rest_unset_required_fields(): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.generate_answer._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "contents", + "answerStyle", + ) + ) + ) + + +def test_generate_answer_rest_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateAnswerResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + safety_settings=[ + safety.SafetySetting( + category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY + ) + ], + answer_style=generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = generative_service.GenerateAnswerResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.generate_answer(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:generateAnswer" % client.transport._host, + args[1], + ) + + +def test_generate_answer_rest_flattened_error(transport: str = "rest"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_answer( + generative_service.GenerateAnswerRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + safety_settings=[ + safety.SafetySetting( + category=safety.HarmCategory.HARM_CATEGORY_DEROGATORY + ) + ], + answer_style=generative_service.GenerateAnswerRequest.AnswerStyle.ABSTRACTIVE, + ) + + +def test_stream_generate_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.stream_generate_content + in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.stream_generate_content + ] = mock_rpc + + request = {} + client.stream_generate_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.stream_generate_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_stream_generate_content_rest_required_fields( + request_type=generative_service.GenerateContentRequest, +): + transport_class = transports.GenerativeServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).stream_generate_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).stream_generate_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateContentResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.GenerateContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + response = client.stream_generate_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_stream_generate_content_rest_unset_required_fields(): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.stream_generate_content._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "contents", + ) + ) + ) + + +def test_stream_generate_content_rest_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateContentResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = generative_service.GenerateContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + with mock.patch.object(response_value, "iter_content") as iter_content: + iter_content.return_value = iter(json_return_value) + client.stream_generate_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:streamGenerateContent" + % client.transport._host, + args[1], + ) + + +def test_stream_generate_content_rest_flattened_error(transport: str = "rest"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.stream_generate_content( + generative_service.GenerateContentRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +def test_embed_content_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.embed_content in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.embed_content] = mock_rpc + + request = {} + client.embed_content(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.embed_content(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_embed_content_rest_required_fields( + request_type=generative_service.EmbedContentRequest, +): + transport_class = transports.GenerativeServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).embed_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).embed_content._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = generative_service.EmbedContentResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.EmbedContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.embed_content(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_embed_content_rest_unset_required_fields(): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.embed_content._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "content", + ) + ) + ) + + +def test_embed_content_rest_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.EmbedContentResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + content=gag_content.Content(parts=[gag_content.Part(text="text_value")]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = generative_service.EmbedContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.embed_content(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:embedContent" % client.transport._host, args[1] + ) + + +def test_embed_content_rest_flattened_error(transport: str = "rest"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.embed_content( + generative_service.EmbedContentRequest(), + model="model_value", + content=gag_content.Content(parts=[gag_content.Part(text="text_value")]), + ) + + +def test_batch_embed_contents_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_embed_contents in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_embed_contents + ] = mock_rpc + + request = {} + client.batch_embed_contents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_embed_contents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_batch_embed_contents_rest_required_fields( + request_type=generative_service.BatchEmbedContentsRequest, +): + transport_class = transports.GenerativeServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_embed_contents._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_embed_contents._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = generative_service.BatchEmbedContentsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.BatchEmbedContentsResponse.pb( + return_value + ) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.batch_embed_contents(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_batch_embed_contents_rest_unset_required_fields(): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.batch_embed_contents._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "requests", + ) + ) + ) + + +def test_batch_embed_contents_rest_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.BatchEmbedContentsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + requests=[generative_service.EmbedContentRequest(model="model_value")], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = generative_service.BatchEmbedContentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.batch_embed_contents(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:batchEmbedContents" % client.transport._host, + args[1], + ) + + +def test_batch_embed_contents_rest_flattened_error(transport: str = "rest"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_embed_contents( + generative_service.BatchEmbedContentsRequest(), + model="model_value", + requests=[generative_service.EmbedContentRequest(model="model_value")], + ) + + +def test_count_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.count_tokens] = mock_rpc + + request = {} + client.count_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_count_tokens_rest_required_fields( + request_type=generative_service.CountTokensRequest, +): + transport_class = transports.GenerativeServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = generative_service.CountTokensResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.CountTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.count_tokens(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_count_tokens_rest_unset_required_fields(): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.count_tokens._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("model",))) + + +def test_count_tokens_rest_flattened(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.CountTokensResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = generative_service.CountTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.count_tokens(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:countTokens" % client.transport._host, args[1] + ) + + +def test_count_tokens_rest_flattened_error(transport: str = "rest"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_tokens( + generative_service.CountTokensRequest(), + model="model_value", + contents=[content.Content(parts=[content.Part(text="text_value")])], + ) + + +def test_bidi_generate_content_rest_no_http_options(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = generative_service.BidiGenerateContentClientMessage() + requests = [request] + with pytest.raises(RuntimeError): + client.bidi_generate_content(requests) + + +def test_bidi_generate_content_rest_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # Since a `google.api.http` annotation is required for using a rest transport + # method, this should error. + with pytest.raises(NotImplementedError) as not_implemented_error: + client.bidi_generate_content({}) + assert "Method BidiGenerateContent is not available over REST transport" in str( + not_implemented_error.value + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.GenerativeServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.GenerativeServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = GenerativeServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.GenerativeServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = GenerativeServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = GenerativeServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.GenerativeServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = GenerativeServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.GenerativeServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = GenerativeServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.GenerativeServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.GenerativeServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceGrpcAsyncIOTransport, + transports.GenerativeServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = GenerativeServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_content_empty_call_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + call.return_value = generative_service.GenerateContentResponse() + client.generate_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_answer_empty_call_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + call.return_value = generative_service.GenerateAnswerResponse() + client.generate_answer(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateAnswerRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_stream_generate_content_empty_call_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + call.return_value = iter([generative_service.GenerateContentResponse()]) + client.stream_generate_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_embed_content_empty_call_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + call.return_value = generative_service.EmbedContentResponse() + client.embed_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.EmbedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_embed_contents_empty_call_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + call.return_value = generative_service.BatchEmbedContentsResponse() + client.batch_embed_contents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.BatchEmbedContentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_count_tokens_empty_call_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + call.return_value = generative_service.CountTokensResponse() + client.count_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.CountTokensRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = GenerativeServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_generate_content_empty_call_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateContentResponse( + model_version="model_version_value", + ) + ) + await client.generate_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_generate_answer_empty_call_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.GenerateAnswerResponse( + answerable_probability=0.234, + ) + ) + await client.generate_answer(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateAnswerRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_stream_generate_content_empty_call_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = mock.Mock(aio.UnaryStreamCall, autospec=True) + call.return_value.read = mock.AsyncMock( + side_effect=[generative_service.GenerateContentResponse()] + ) + await client.stream_generate_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_embed_content_empty_call_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.EmbedContentResponse() + ) + await client.embed_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.EmbedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_batch_embed_contents_empty_call_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.BatchEmbedContentsResponse() + ) + await client.batch_embed_contents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.BatchEmbedContentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_count_tokens_empty_call_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + generative_service.CountTokensResponse( + total_tokens=1303, + cached_content_token_count=2746, + ) + ) + await client.count_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.CountTokensRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = GenerativeServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_generate_content_rest_bad_request( + request_type=generative_service.GenerateContentRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.generate_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.GenerateContentRequest, + dict, + ], +) +def test_generate_content_rest_call_success(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateContentResponse( + model_version="model_version_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.GenerateContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.generate_content(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateContentResponse) + assert response.model_version == "model_version_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_generate_content_rest_interceptors(null_interceptor): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.GenerativeServiceRestInterceptor(), + ) + client = GenerativeServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "post_generate_content" + ) as post, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "pre_generate_content" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = generative_service.GenerateContentRequest.pb( + generative_service.GenerateContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = generative_service.GenerateContentResponse.to_json( + generative_service.GenerateContentResponse() + ) + req.return_value.content = return_value + + request = generative_service.GenerateContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = generative_service.GenerateContentResponse() + + client.generate_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_generate_answer_rest_bad_request( + request_type=generative_service.GenerateAnswerRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.generate_answer(request) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.GenerateAnswerRequest, + dict, + ], +) +def test_generate_answer_rest_call_success(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateAnswerResponse( + answerable_probability=0.234, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.GenerateAnswerResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.generate_answer(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateAnswerResponse) + assert math.isclose(response.answerable_probability, 0.234, rel_tol=1e-6) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_generate_answer_rest_interceptors(null_interceptor): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.GenerativeServiceRestInterceptor(), + ) + client = GenerativeServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "post_generate_answer" + ) as post, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "pre_generate_answer" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = generative_service.GenerateAnswerRequest.pb( + generative_service.GenerateAnswerRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = generative_service.GenerateAnswerResponse.to_json( + generative_service.GenerateAnswerResponse() + ) + req.return_value.content = return_value + + request = generative_service.GenerateAnswerRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = generative_service.GenerateAnswerResponse() + + client.generate_answer( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_stream_generate_content_rest_bad_request( + request_type=generative_service.GenerateContentRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.stream_generate_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.GenerateContentRequest, + dict, + ], +) +def test_stream_generate_content_rest_call_success(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.GenerateContentResponse( + model_version="model_version_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.GenerateContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + json_return_value = "[{}]".format(json_return_value) + response_value.iter_content = mock.Mock(return_value=iter(json_return_value)) + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.stream_generate_content(request) + + assert isinstance(response, Iterable) + response = next(response) + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.GenerateContentResponse) + assert response.model_version == "model_version_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_stream_generate_content_rest_interceptors(null_interceptor): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.GenerativeServiceRestInterceptor(), + ) + client = GenerativeServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "post_stream_generate_content" + ) as post, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "pre_stream_generate_content" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = generative_service.GenerateContentRequest.pb( + generative_service.GenerateContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = generative_service.GenerateContentResponse.to_json( + generative_service.GenerateContentResponse() + ) + req.return_value.iter_content = mock.Mock(return_value=iter(return_value)) + + request = generative_service.GenerateContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = generative_service.GenerateContentResponse() + + client.stream_generate_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_embed_content_rest_bad_request( + request_type=generative_service.EmbedContentRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.embed_content(request) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.EmbedContentRequest, + dict, + ], +) +def test_embed_content_rest_call_success(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.EmbedContentResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.EmbedContentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.embed_content(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.EmbedContentResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_embed_content_rest_interceptors(null_interceptor): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.GenerativeServiceRestInterceptor(), + ) + client = GenerativeServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "post_embed_content" + ) as post, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "pre_embed_content" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = generative_service.EmbedContentRequest.pb( + generative_service.EmbedContentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = generative_service.EmbedContentResponse.to_json( + generative_service.EmbedContentResponse() + ) + req.return_value.content = return_value + + request = generative_service.EmbedContentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = generative_service.EmbedContentResponse() + + client.embed_content( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_batch_embed_contents_rest_bad_request( + request_type=generative_service.BatchEmbedContentsRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.batch_embed_contents(request) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.BatchEmbedContentsRequest, + dict, + ], +) +def test_batch_embed_contents_rest_call_success(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.BatchEmbedContentsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.BatchEmbedContentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.batch_embed_contents(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.BatchEmbedContentsResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_batch_embed_contents_rest_interceptors(null_interceptor): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.GenerativeServiceRestInterceptor(), + ) + client = GenerativeServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "post_batch_embed_contents" + ) as post, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "pre_batch_embed_contents" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = generative_service.BatchEmbedContentsRequest.pb( + generative_service.BatchEmbedContentsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = generative_service.BatchEmbedContentsResponse.to_json( + generative_service.BatchEmbedContentsResponse() + ) + req.return_value.content = return_value + + request = generative_service.BatchEmbedContentsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = generative_service.BatchEmbedContentsResponse() + + client.batch_embed_contents( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_count_tokens_rest_bad_request( + request_type=generative_service.CountTokensRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.count_tokens(request) + + +@pytest.mark.parametrize( + "request_type", + [ + generative_service.CountTokensRequest, + dict, + ], +) +def test_count_tokens_rest_call_success(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = generative_service.CountTokensResponse( + total_tokens=1303, + cached_content_token_count=2746, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = generative_service.CountTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.count_tokens(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, generative_service.CountTokensResponse) + assert response.total_tokens == 1303 + assert response.cached_content_token_count == 2746 + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_count_tokens_rest_interceptors(null_interceptor): + transport = transports.GenerativeServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.GenerativeServiceRestInterceptor(), + ) + client = GenerativeServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "post_count_tokens" + ) as post, mock.patch.object( + transports.GenerativeServiceRestInterceptor, "pre_count_tokens" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = generative_service.CountTokensRequest.pb( + generative_service.CountTokensRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = generative_service.CountTokensResponse.to_json( + generative_service.CountTokensResponse() + ) + req.return_value.content = return_value + + request = generative_service.CountTokensRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = generative_service.CountTokensResponse() + + client.count_tokens( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_bidi_generate_content_rest_error(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + with pytest.raises(NotImplementedError) as not_implemented_error: + client.bidi_generate_content({}) + assert "Method BidiGenerateContent is not available over REST transport" in str( + not_implemented_error.value + ) + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_content_empty_call_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_content), "__call__") as call: + client.generate_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_answer_empty_call_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_answer), "__call__") as call: + client.generate_answer(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateAnswerRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_stream_generate_content_empty_call_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.stream_generate_content), "__call__" + ) as call: + client.stream_generate_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.GenerateContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_embed_content_empty_call_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.embed_content), "__call__") as call: + client.embed_content(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.EmbedContentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_embed_contents_empty_call_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_embed_contents), "__call__" + ) as call: + client.batch_embed_contents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.BatchEmbedContentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_count_tokens_empty_call_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.count_tokens), "__call__") as call: + client.count_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = generative_service.CountTokensRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.GenerativeServiceGrpcTransport, + ) + + +def test_generative_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.GenerativeServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_generative_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.generative_service.transports.GenerativeServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.GenerativeServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "generate_content", + "generate_answer", + "stream_generate_content", + "embed_content", + "batch_embed_contents", + "count_tokens", + "bidi_generate_content", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_generative_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.generative_service.transports.GenerativeServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.GenerativeServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_generative_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.generative_service.transports.GenerativeServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.GenerativeServiceTransport() + adc.assert_called_once() + + +def test_generative_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + GenerativeServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceGrpcAsyncIOTransport, + ], +) +def test_generative_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceGrpcAsyncIOTransport, + transports.GenerativeServiceRestTransport, + ], +) +def test_generative_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.GenerativeServiceGrpcTransport, grpc_helpers), + (transports.GenerativeServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_generative_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceGrpcAsyncIOTransport, + ], +) +def test_generative_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_generative_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.GenerativeServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_generative_service_host_no_port(transport_name): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_generative_service_host_with_port(transport_name): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_generative_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = GenerativeServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = GenerativeServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.generate_content._session + session2 = client2.transport.generate_content._session + assert session1 != session2 + session1 = client1.transport.generate_answer._session + session2 = client2.transport.generate_answer._session + assert session1 != session2 + session1 = client1.transport.stream_generate_content._session + session2 = client2.transport.stream_generate_content._session + assert session1 != session2 + session1 = client1.transport.embed_content._session + session2 = client2.transport.embed_content._session + assert session1 != session2 + session1 = client1.transport.batch_embed_contents._session + session2 = client2.transport.batch_embed_contents._session + assert session1 != session2 + session1 = client1.transport.count_tokens._session + session2 = client2.transport.count_tokens._session + assert session1 != session2 + session1 = client1.transport.bidi_generate_content._session + session2 = client2.transport.bidi_generate_content._session + assert session1 != session2 + + +def test_generative_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.GenerativeServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_generative_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.GenerativeServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceGrpcAsyncIOTransport, + ], +) +def test_generative_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.GenerativeServiceGrpcTransport, + transports.GenerativeServiceGrpcAsyncIOTransport, + ], +) +def test_generative_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_cached_content_path(): + id = "squid" + expected = "cachedContents/{id}".format( + id=id, + ) + actual = GenerativeServiceClient.cached_content_path(id) + assert expected == actual + + +def test_parse_cached_content_path(): + expected = { + "id": "clam", + } + path = GenerativeServiceClient.cached_content_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_cached_content_path(path) + assert expected == actual + + +def test_model_path(): + model = "whelk" + expected = "models/{model}".format( + model=model, + ) + actual = GenerativeServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "octopus", + } + path = GenerativeServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "oyster" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = GenerativeServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = GenerativeServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "cuttlefish" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = GenerativeServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = GenerativeServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = GenerativeServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = GenerativeServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + expected = "projects/{project}".format( + project=project, + ) + actual = GenerativeServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = GenerativeServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = GenerativeServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = GenerativeServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = GenerativeServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.GenerativeServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.GenerativeServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = GenerativeServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = GenerativeServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = GenerativeServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (GenerativeServiceClient, transports.GenerativeServiceGrpcTransport), + ( + GenerativeServiceAsyncClient, + transports.GenerativeServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_model_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_model_service.py new file mode 100644 index 000000000000..8269673af5af --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_model_service.py @@ -0,0 +1,7901 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import ( + future, + gapic_v1, + grpc_helpers, + grpc_helpers_async, + operation, + operations_v1, + path_template, +) +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import operation_async # type: ignore +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.model_service import ( + ModelServiceAsyncClient, + ModelServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1alpha.types import tuned_model as gag_tuned_model +from google.ai.generativelanguage_v1alpha.types import model, model_service +from google.ai.generativelanguage_v1alpha.types import tuned_model + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert ModelServiceClient._get_default_mtls_endpoint(None) is None + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + ModelServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ModelServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +def test__read_environment_variables(): + assert ModelServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert ModelServiceClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert ModelServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + ModelServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert ModelServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert ModelServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert ModelServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + ModelServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert ModelServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert ModelServiceClient._get_client_cert_source(None, False) is None + assert ( + ModelServiceClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + ModelServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + ModelServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + ModelServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + ModelServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceClient), +) +@mock.patch.object( + ModelServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = ModelServiceClient._DEFAULT_UNIVERSE + default_endpoint = ModelServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = ModelServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + ModelServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + ModelServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == ModelServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + ModelServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + ModelServiceClient._get_api_endpoint(None, None, default_universe, "always") + == ModelServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + ModelServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == ModelServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + ModelServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + ModelServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + ModelServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + ModelServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + ModelServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + ModelServiceClient._get_universe_domain(None, None) + == ModelServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + ModelServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), + ], +) +def test_model_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.ModelServiceGrpcTransport, "grpc"), + (transports.ModelServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.ModelServiceRestTransport, "rest"), + ], +) +def test_model_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (ModelServiceClient, "grpc"), + (ModelServiceAsyncClient, "grpc_asyncio"), + (ModelServiceClient, "rest"), + ], +) +def test_model_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_model_service_client_get_transport_class(): + transport = ModelServiceClient.get_transport_class() + available_transports = [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceRestTransport, + ] + assert transport in available_transports + + transport = ModelServiceClient.get_transport_class("grpc") + assert transport == transports.ModelServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + ModelServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceClient), +) +@mock.patch.object( + ModelServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceAsyncClient), +) +def test_model_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(ModelServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "true"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc", "false"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "true"), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + ModelServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceClient), +) +@mock.patch.object( + ModelServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_model_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +@mock.patch.object( + ModelServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(ModelServiceClient) +) +@mock.patch.object( + ModelServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(ModelServiceAsyncClient), +) +def test_model_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize("client_class", [ModelServiceClient, ModelServiceAsyncClient]) +@mock.patch.object( + ModelServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceClient), +) +@mock.patch.object( + ModelServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(ModelServiceAsyncClient), +) +def test_model_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = ModelServiceClient._DEFAULT_UNIVERSE + default_endpoint = ModelServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = ModelServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport, "grpc"), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest"), + ], +) +def test_model_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ModelServiceClient, + transports.ModelServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (ModelServiceClient, transports.ModelServiceRestTransport, "rest", None), + ], +) +def test_model_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_model_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.model_service.transports.ModelServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = ModelServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + ModelServiceClient, + transports.ModelServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + ModelServiceAsyncClient, + transports.ModelServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_model_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetModelRequest, + dict, + ], +) +def test_get_model(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model( + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=["supported_generation_methods_value"], + temperature=0.1198, + max_temperature=0.16190000000000002, + top_p=0.546, + top_k=541, + ) + response = client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.GetModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + assert response.name == "name_value" + assert response.base_model_id == "base_model_id_value" + assert response.version == "version_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.input_token_limit == 1838 + assert response.output_token_limit == 1967 + assert response.supported_generation_methods == [ + "supported_generation_methods_value" + ] + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.max_temperature, 0.16190000000000002, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + + +def test_get_model_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.GetModelRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_model(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.GetModelRequest( + name="name_value", + ) + + +def test_get_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_model] = mock_rpc + request = {} + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_model_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_model + ] = mock_rpc + + request = {} + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetModelRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=["supported_generation_methods_value"], + temperature=0.1198, + max_temperature=0.16190000000000002, + top_p=0.546, + top_k=541, + ) + ) + response = await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.GetModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + assert response.name == "name_value" + assert response.base_model_id == "base_model_id_value" + assert response.version == "version_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.input_token_limit == 1838 + assert response.output_token_limit == 1967 + assert response.supported_generation_methods == [ + "supported_generation_methods_value" + ] + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.max_temperature, 0.16190000000000002, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + + +@pytest.mark.asyncio +async def test_get_model_async_from_dict(): + await test_get_model_async(request_type=dict) + + +def test_get_model_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value = model.Model() + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + await client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model( + model_service.GetModelRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_model_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model.Model() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(model.Model()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_model( + model_service.GetModelRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListModelsRequest, + dict, + ], +) +def test_list_models(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.ListModelsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_models_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.ListModelsRequest( + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_models(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.ListModelsRequest( + page_token="page_token_value", + ) + + +def test_list_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_models] = mock_rpc + request = {} + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_models + ] = mock_rpc + + request = {} + await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListModelsRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.ListModelsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_models_async_from_dict(): + await test_list_models_async(request_type=dict) + + +def test_list_models_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_models( + page_size=951, + page_token="page_token_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].page_size + mock_val = 951 + assert arg == mock_val + arg = args[0].page_token + mock_val = "page_token_value" + assert arg == mock_val + + +def test_list_models_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_models( + model_service.ListModelsRequest(), + page_size=951, + page_token="page_token_value", + ) + + +@pytest.mark.asyncio +async def test_list_models_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListModelsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_models( + page_size=951, + page_token="page_token_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].page_size + mock_val = 951 + assert arg == mock_val + arg = args[0].page_token + mock_val = "page_token_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_models_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_models( + model_service.ListModelsRequest(), + page_size=951, + page_token="page_token_value", + ) + + +def test_list_models_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token="abc", + ), + model_service.ListModelsResponse( + models=[], + next_page_token="def", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token="ghi", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + pager = client.list_models(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, model.Model) for i in results) + + +def test_list_models_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token="abc", + ), + model_service.ListModelsResponse( + models=[], + next_page_token="def", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token="ghi", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], + ), + RuntimeError, + ) + pages = list(client.list_models(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_models_async_pager(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token="abc", + ), + model_service.ListModelsResponse( + models=[], + next_page_token="def", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token="ghi", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_models( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, model.Model) for i in responses) + + +@pytest.mark.asyncio +async def test_list_models_async_pages(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_models), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token="abc", + ), + model_service.ListModelsResponse( + models=[], + next_page_token="def", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token="ghi", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_models(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetTunedModelRequest, + dict, + ], +) +def test_get_tuned_model(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + base_model="base_model_value", + ) + response = client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.GetTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, tuned_model.TunedModel) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == tuned_model.TunedModel.State.CREATING + assert response.reader_project_numbers == [2340] + + +def test_get_tuned_model_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.GetTunedModelRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_tuned_model(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.GetTunedModelRequest( + name="name_value", + ) + + +def test_get_tuned_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tuned_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tuned_model] = mock_rpc + request = {} + client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_tuned_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_tuned_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_tuned_model + ] = mock_rpc + + request = {} + await client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.GetTunedModelRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + ) + ) + response = await client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.GetTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, tuned_model.TunedModel) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == tuned_model.TunedModel.State.CREATING + assert response.reader_project_numbers == [2340] + + +@pytest.mark.asyncio +async def test_get_tuned_model_async_from_dict(): + await test_get_tuned_model_async(request_type=dict) + + +def test_get_tuned_model_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetTunedModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + call.return_value = tuned_model.TunedModel() + client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_tuned_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.GetTunedModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel() + ) + await client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_tuned_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = tuned_model.TunedModel() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_tuned_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_tuned_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tuned_model( + model_service.GetTunedModelRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_tuned_model_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = tuned_model.TunedModel() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_tuned_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_tuned_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_tuned_model( + model_service.GetTunedModelRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListTunedModelsRequest, + dict, + ], +) +def test_list_tuned_models(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListTunedModelsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_tuned_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.ListTunedModelsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTunedModelsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_tuned_models_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.ListTunedModelsRequest( + page_token="page_token_value", + filter="filter_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_tuned_models(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.ListTunedModelsRequest( + page_token="page_token_value", + filter="filter_value", + ) + + +def test_list_tuned_models_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tuned_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tuned_models + ] = mock_rpc + request = {} + client.list_tuned_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_tuned_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_tuned_models_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_tuned_models + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_tuned_models + ] = mock_rpc + + request = {} + await client.list_tuned_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_tuned_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_tuned_models_async( + transport: str = "grpc_asyncio", request_type=model_service.ListTunedModelsRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListTunedModelsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_tuned_models(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.ListTunedModelsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTunedModelsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_tuned_models_async_from_dict(): + await test_list_tuned_models_async(request_type=dict) + + +def test_list_tuned_models_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListTunedModelsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_tuned_models( + page_size=951, + page_token="page_token_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].page_size + mock_val = 951 + assert arg == mock_val + arg = args[0].page_token + mock_val = "page_token_value" + assert arg == mock_val + + +def test_list_tuned_models_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tuned_models( + model_service.ListTunedModelsRequest(), + page_size=951, + page_token="page_token_value", + ) + + +@pytest.mark.asyncio +async def test_list_tuned_models_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = model_service.ListTunedModelsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListTunedModelsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_tuned_models( + page_size=951, + page_token="page_token_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].page_size + mock_val = 951 + assert arg == mock_val + arg = args[0].page_token + mock_val = "page_token_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_tuned_models_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_tuned_models( + model_service.ListTunedModelsRequest(), + page_size=951, + page_token="page_token_value", + ) + + +def test_list_tuned_models_pager(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token="abc", + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token="def", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token="ghi", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + pager = client.list_tuned_models(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, tuned_model.TunedModel) for i in results) + + +def test_list_tuned_models_pages(transport_name: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token="abc", + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token="def", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token="ghi", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + RuntimeError, + ) + pages = list(client.list_tuned_models(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_tuned_models_async_pager(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token="abc", + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token="def", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token="ghi", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_tuned_models( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, tuned_model.TunedModel) for i in responses) + + +@pytest.mark.asyncio +async def test_list_tuned_models_async_pages(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), + "__call__", + new_callable=mock.AsyncMock, + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token="abc", + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token="def", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token="ghi", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_tuned_models(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.CreateTunedModelRequest, + dict, + ], +) +def test_create_tuned_model(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/spam") + response = client.create_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.CreateTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +def test_create_tuned_model_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.CreateTunedModelRequest( + tuned_model_id="tuned_model_id_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_tuned_model(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.CreateTunedModelRequest( + tuned_model_id="tuned_model_id_value", + ) + + +def test_create_tuned_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tuned_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tuned_model + ] = mock_rpc + request = {} + client.create_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_tuned_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_tuned_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_tuned_model + ] = mock_rpc + + request = {} + await client.create_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods call wrapper_fn to build a cached + # client._transport.operations_client instance on first rpc call. + # Subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + await client.create_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.CreateTunedModelRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + response = await client.create_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.CreateTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, future.Future) + + +@pytest.mark.asyncio +async def test_create_tuned_model_async_from_dict(): + await test_create_tuned_model_async(request_type=dict) + + +def test_create_tuned_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_tuned_model( + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].tuned_model + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) + assert arg == mock_val + arg = args[0].tuned_model_id + mock_val = "tuned_model_id_value" + assert arg == mock_val + + +def test_create_tuned_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_tuned_model( + model_service.CreateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", + ) + + +@pytest.mark.asyncio +async def test_create_tuned_model_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation(name="operations/op") + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_tuned_model( + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].tuned_model + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) + assert arg == mock_val + arg = args[0].tuned_model_id + mock_val = "tuned_model_id_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_tuned_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_tuned_model( + model_service.CreateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.UpdateTunedModelRequest, + dict, + ], +) +def test_update_tuned_model(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + base_model="base_model_value", + ) + response = client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.UpdateTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_tuned_model.TunedModel) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == gag_tuned_model.TunedModel.State.CREATING + assert response.reader_project_numbers == [2340] + + +def test_update_tuned_model_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.UpdateTunedModelRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_tuned_model(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.UpdateTunedModelRequest() + + +def test_update_tuned_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tuned_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tuned_model + ] = mock_rpc + request = {} + client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_tuned_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_tuned_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_tuned_model + ] = mock_rpc + + request = {} + await client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.update_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.UpdateTunedModelRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + ) + ) + response = await client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.UpdateTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_tuned_model.TunedModel) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == gag_tuned_model.TunedModel.State.CREATING + assert response.reader_project_numbers == [2340] + + +@pytest.mark.asyncio +async def test_update_tuned_model_async_from_dict(): + await test_update_tuned_model_async(request_type=dict) + + +def test_update_tuned_model_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UpdateTunedModelRequest() + + request.tuned_model.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + call.return_value = gag_tuned_model.TunedModel() + client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tuned_model.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_tuned_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.UpdateTunedModelRequest() + + request.tuned_model.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel() + ) + await client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "tuned_model.name=name_value", + ) in kw["metadata"] + + +def test_update_tuned_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_tuned_model.TunedModel() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_tuned_model( + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].tuned_model + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_tuned_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tuned_model( + model_service.UpdateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_tuned_model_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_tuned_model.TunedModel() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_tuned_model( + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].tuned_model + mock_val = gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ) + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_tuned_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_tuned_model( + model_service.UpdateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.DeleteTunedModelRequest, + dict, + ], +) +def test_delete_tuned_model(request_type, transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = model_service.DeleteTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_tuned_model_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = model_service.DeleteTunedModelRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_tuned_model(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == model_service.DeleteTunedModelRequest( + name="name_value", + ) + + +def test_delete_tuned_model_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tuned_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tuned_model + ] = mock_rpc + request = {} + client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_tuned_model_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_tuned_model + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_tuned_model + ] = mock_rpc + + request = {} + await client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_tuned_model_async( + transport: str = "grpc_asyncio", request_type=model_service.DeleteTunedModelRequest +): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = model_service.DeleteTunedModelRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_tuned_model_async_from_dict(): + await test_delete_tuned_model_async(request_type=dict) + + +def test_delete_tuned_model_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.DeleteTunedModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + call.return_value = None + client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_tuned_model_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = model_service.DeleteTunedModelRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_tuned_model_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_tuned_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_tuned_model_flattened_error(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tuned_model( + model_service.DeleteTunedModelRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_tuned_model_flattened_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_tuned_model( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_tuned_model_flattened_error_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_tuned_model( + model_service.DeleteTunedModelRequest(), + name="name_value", + ) + + +def test_get_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_model] = mock_rpc + + request = {} + client.get_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_model_rest_required_fields(request_type=model_service.GetModelRequest): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = model.Model() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = model.Model.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_model(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_model_rest_unset_required_fields(): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_model._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = model.Model() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = model.Model.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=models/*}" % client.transport._host, args[1] + ) + + +def test_get_model_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_model( + model_service.GetModelRequest(), + name="name_value", + ) + + +def test_list_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_models] = mock_rpc + + request = {} + client.list_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_models_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = model_service.ListModelsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + page_size=951, + page_token="page_token_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = model_service.ListModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.list_models(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/models" % client.transport._host, args[1] + ) + + +def test_list_models_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_models( + model_service.ListModelsRequest(), + page_size=951, + page_token="page_token_value", + ) + + +def test_list_models_rest_pager(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + model.Model(), + ], + next_page_token="abc", + ), + model_service.ListModelsResponse( + models=[], + next_page_token="def", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + ], + next_page_token="ghi", + ), + model_service.ListModelsResponse( + models=[ + model.Model(), + model.Model(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple(model_service.ListModelsResponse.to_json(x) for x in response) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_models(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, model.Model) for i in results) + + pages = list(client.list_models(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_get_tuned_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_tuned_model in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_tuned_model] = mock_rpc + + request = {} + client.get_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_tuned_model_rest_required_fields( + request_type=model_service.GetTunedModelRequest, +): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = tuned_model.TunedModel() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_tuned_model(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_tuned_model_rest_unset_required_fields(): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_tuned_model._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_tuned_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = tuned_model.TunedModel() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "tunedModels/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=tunedModels/*}" % client.transport._host, args[1] + ) + + +def test_get_tuned_model_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_tuned_model( + model_service.GetTunedModelRequest(), + name="name_value", + ) + + +def test_list_tuned_models_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_tuned_models in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_tuned_models + ] = mock_rpc + + request = {} + client.list_tuned_models(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_tuned_models(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_tuned_models_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = model_service.ListTunedModelsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + page_size=951, + page_token="page_token_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = model_service.ListTunedModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.list_tuned_models(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/tunedModels" % client.transport._host, args[1] + ) + + +def test_list_tuned_models_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_tuned_models( + model_service.ListTunedModelsRequest(), + page_size=951, + page_token="page_token_value", + ) + + +def test_list_tuned_models_rest_pager(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + next_page_token="abc", + ), + model_service.ListTunedModelsResponse( + tuned_models=[], + next_page_token="def", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + ], + next_page_token="ghi", + ), + model_service.ListTunedModelsResponse( + tuned_models=[ + tuned_model.TunedModel(), + tuned_model.TunedModel(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + model_service.ListTunedModelsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_tuned_models(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, tuned_model.TunedModel) for i in results) + + pages = list(client.list_tuned_models(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_create_tuned_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.create_tuned_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_tuned_model + ] = mock_rpc + + request = {} + client.create_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + # Operation methods build a cached wrapper on first rpc call + # subsequent calls should use the cached wrapper + wrapper_fn.reset_mock() + + client.create_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_tuned_model_rest_required_fields( + request_type=model_service.CreateTunedModelRequest, +): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_tuned_model._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("tuned_model_id",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_tuned_model(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_tuned_model_rest_unset_required_fields(): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_tuned_model._get_unset_required_fields({}) + assert set(unset_fields) == (set(("tunedModelId",)) & set(("tunedModel",))) + + +def test_create_tuned_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/tunedModels" % client.transport._host, args[1] + ) + + +def test_create_tuned_model_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_tuned_model( + model_service.CreateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + tuned_model_id="tuned_model_id_value", + ) + + +def test_update_tuned_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.update_tuned_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_tuned_model + ] = mock_rpc + + request = {} + client.update_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_tuned_model_rest_required_fields( + request_type=model_service.UpdateTunedModelRequest, +): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_tuned_model._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_tuned_model.TunedModel() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_tuned_model(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_tuned_model_rest_unset_required_fields(): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_tuned_model._get_unset_required_fields({}) + assert set(unset_fields) == (set(("updateMask",)) & set(("tunedModel",))) + + +def test_update_tuned_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_tuned_model.TunedModel() + + # get arguments that satisfy an http rule for this method + sample_request = {"tuned_model": {"name": "tunedModels/sample1"}} + + # get truthy value for each flattened field + mock_args = dict( + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gag_tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{tuned_model.name=tunedModels/*}" % client.transport._host, + args[1], + ) + + +def test_update_tuned_model_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_tuned_model( + model_service.UpdateTunedModelRequest(), + tuned_model=gag_tuned_model.TunedModel( + tuned_model_source=gag_tuned_model.TunedModelSource( + tuned_model="tuned_model_value" + ) + ), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_tuned_model_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.delete_tuned_model in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_tuned_model + ] = mock_rpc + + request = {} + client.delete_tuned_model(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_tuned_model(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_tuned_model_rest_required_fields( + request_type=model_service.DeleteTunedModelRequest, +): + transport_class = transports.ModelServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_tuned_model._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_tuned_model(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_tuned_model_rest_unset_required_fields(): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_tuned_model._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_delete_tuned_model_rest_flattened(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "tunedModels/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_tuned_model(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=tunedModels/*}" % client.transport._host, args[1] + ) + + +def test_delete_tuned_model_rest_flattened_error(transport: str = "rest"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_tuned_model( + model_service.DeleteTunedModelRequest(), + name="name_value", + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = ModelServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = ModelServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.ModelServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.ModelServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + transports.ModelServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = ModelServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_model_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + call.return_value = model.Model() + client.get_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.GetModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_models_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + call.return_value = model_service.ListModelsResponse() + client.list_models(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.ListModelsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_tuned_model_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + call.return_value = tuned_model.TunedModel() + client.get_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.GetTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_tuned_models_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + call.return_value = model_service.ListTunedModelsResponse() + client.list_tuned_models(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.ListTunedModelsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_tuned_model_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + call.return_value = operations_pb2.Operation(name="operations/op") + client.create_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.CreateTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_tuned_model_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + call.return_value = gag_tuned_model.TunedModel() + client.update_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.UpdateTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_tuned_model_empty_call_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + call.return_value = None + client.delete_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.DeleteTunedModelRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = ModelServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_model_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model.Model( + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=["supported_generation_methods_value"], + temperature=0.1198, + max_temperature=0.16190000000000002, + top_p=0.546, + top_k=541, + ) + ) + await client.get_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.GetModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_models_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_models(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.ListModelsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_tuned_model_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + ) + ) + await client.get_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.GetTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_tuned_models_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + model_service.ListTunedModelsResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_tuned_models(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.ListTunedModelsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_tuned_model_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation(name="operations/spam") + ) + await client.create_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.CreateTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_tuned_model_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + ) + ) + await client.update_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.UpdateTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_tuned_model_empty_call_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.DeleteTunedModelRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = ModelServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_get_model_rest_bad_request(request_type=model_service.GetModelRequest): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_model(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetModelRequest, + dict, + ], +) +def test_get_model_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = model.Model( + name="name_value", + base_model_id="base_model_id_value", + version="version_value", + display_name="display_name_value", + description="description_value", + input_token_limit=1838, + output_token_limit=1967, + supported_generation_methods=["supported_generation_methods_value"], + temperature=0.1198, + max_temperature=0.16190000000000002, + top_p=0.546, + top_k=541, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = model.Model.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_model(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, model.Model) + assert response.name == "name_value" + assert response.base_model_id == "base_model_id_value" + assert response.version == "version_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert response.input_token_limit == 1838 + assert response.output_token_limit == 1967 + assert response.supported_generation_methods == [ + "supported_generation_methods_value" + ] + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.max_temperature, 0.16190000000000002, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_model_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_get_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_get_model" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = model_service.GetModelRequest.pb(model_service.GetModelRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = model.Model.to_json(model.Model()) + req.return_value.content = return_value + + request = model_service.GetModelRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = model.Model() + + client.get_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_models_rest_bad_request(request_type=model_service.ListModelsRequest): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_models(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListModelsRequest, + dict, + ], +) +def test_list_models_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = model_service.ListModelsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = model_service.ListModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_models(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListModelsPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_models_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_list_models" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_list_models" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = model_service.ListModelsRequest.pb( + model_service.ListModelsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = model_service.ListModelsResponse.to_json( + model_service.ListModelsResponse() + ) + req.return_value.content = return_value + + request = model_service.ListModelsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = model_service.ListModelsResponse() + + client.list_models( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_tuned_model_rest_bad_request( + request_type=model_service.GetTunedModelRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_tuned_model(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.GetTunedModelRequest, + dict, + ], +) +def test_get_tuned_model_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + base_model="base_model_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_tuned_model(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, tuned_model.TunedModel) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == tuned_model.TunedModel.State.CREATING + assert response.reader_project_numbers == [2340] + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_tuned_model_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_get_tuned_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_get_tuned_model" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = model_service.GetTunedModelRequest.pb( + model_service.GetTunedModelRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = tuned_model.TunedModel.to_json(tuned_model.TunedModel()) + req.return_value.content = return_value + + request = model_service.GetTunedModelRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = tuned_model.TunedModel() + + client.get_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_tuned_models_rest_bad_request( + request_type=model_service.ListTunedModelsRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_tuned_models(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.ListTunedModelsRequest, + dict, + ], +) +def test_list_tuned_models_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = model_service.ListTunedModelsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = model_service.ListTunedModelsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_tuned_models(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListTunedModelsPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_tuned_models_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_list_tuned_models" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_list_tuned_models" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = model_service.ListTunedModelsRequest.pb( + model_service.ListTunedModelsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = model_service.ListTunedModelsResponse.to_json( + model_service.ListTunedModelsResponse() + ) + req.return_value.content = return_value + + request = model_service.ListTunedModelsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = model_service.ListTunedModelsResponse() + + client.list_tuned_models( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_tuned_model_rest_bad_request( + request_type=model_service.CreateTunedModelRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_tuned_model(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.CreateTunedModelRequest, + dict, + ], +) +def test_create_tuned_model_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request_init["tuned_model"] = { + "tuned_model_source": { + "tuned_model": "tuned_model_value", + "base_model": "base_model_value", + }, + "base_model": "base_model_value", + "name": "name_value", + "display_name": "display_name_value", + "description": "description_value", + "temperature": 0.1198, + "top_p": 0.546, + "top_k": 541, + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "tuning_task": { + "start_time": {}, + "complete_time": {}, + "snapshots": [ + {"step": 444, "epoch": 527, "mean_loss": 0.961, "compute_time": {}} + ], + "training_data": { + "examples": { + "examples": [ + {"text_input": "text_input_value", "output": "output_value"} + ], + "multiturn_examples": [ + { + "system_instruction": { + "parts": [{"text": "text_value"}], + "role": "role_value", + }, + "contents": {}, + } + ], + } + }, + "hyperparameters": { + "learning_rate": 0.1371, + "learning_rate_multiplier": 0.2561, + "epoch_count": 1175, + "batch_size": 1052, + }, + }, + "reader_project_numbers": [2341, 2342], + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = model_service.CreateTunedModelRequest.meta.fields["tuned_model"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["tuned_model"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["tuned_model"][field])): + del request_init["tuned_model"][field][i][subfield] + else: + del request_init["tuned_model"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation(name="operations/spam") + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_tuned_model(request) + + # Establish that the response is the type that we expect. + json_return_value = json_format.MessageToJson(return_value) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_tuned_model_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + operation.Operation, "_set_result_from_operation" + ), mock.patch.object( + transports.ModelServiceRestInterceptor, "post_create_tuned_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_create_tuned_model" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = model_service.CreateTunedModelRequest.pb( + model_service.CreateTunedModelRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = json_format.MessageToJson(operations_pb2.Operation()) + req.return_value.content = return_value + + request = model_service.CreateTunedModelRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = operations_pb2.Operation() + + client.create_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_tuned_model_rest_bad_request( + request_type=model_service.UpdateTunedModelRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"tuned_model": {"name": "tunedModels/sample1"}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.update_tuned_model(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.UpdateTunedModelRequest, + dict, + ], +) +def test_update_tuned_model_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"tuned_model": {"name": "tunedModels/sample1"}} + request_init["tuned_model"] = { + "tuned_model_source": { + "tuned_model": "tuned_model_value", + "base_model": "base_model_value", + }, + "base_model": "base_model_value", + "name": "tunedModels/sample1", + "display_name": "display_name_value", + "description": "description_value", + "temperature": 0.1198, + "top_p": 0.546, + "top_k": 541, + "state": 1, + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "tuning_task": { + "start_time": {}, + "complete_time": {}, + "snapshots": [ + {"step": 444, "epoch": 527, "mean_loss": 0.961, "compute_time": {}} + ], + "training_data": { + "examples": { + "examples": [ + {"text_input": "text_input_value", "output": "output_value"} + ], + "multiturn_examples": [ + { + "system_instruction": { + "parts": [{"text": "text_value"}], + "role": "role_value", + }, + "contents": {}, + } + ], + } + }, + "hyperparameters": { + "learning_rate": 0.1371, + "learning_rate_multiplier": 0.2561, + "epoch_count": 1175, + "batch_size": 1052, + }, + }, + "reader_project_numbers": [2341, 2342], + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = model_service.UpdateTunedModelRequest.meta.fields["tuned_model"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["tuned_model"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["tuned_model"][field])): + del request_init["tuned_model"][field][i][subfield] + else: + del request_init["tuned_model"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_tuned_model.TunedModel( + name="name_value", + display_name="display_name_value", + description="description_value", + temperature=0.1198, + top_p=0.546, + top_k=541, + state=gag_tuned_model.TunedModel.State.CREATING, + reader_project_numbers=[2340], + base_model="base_model_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_tuned_model.TunedModel.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.update_tuned_model(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_tuned_model.TunedModel) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + assert response.description == "description_value" + assert math.isclose(response.temperature, 0.1198, rel_tol=1e-6) + assert math.isclose(response.top_p, 0.546, rel_tol=1e-6) + assert response.top_k == 541 + assert response.state == gag_tuned_model.TunedModel.State.CREATING + assert response.reader_project_numbers == [2340] + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_tuned_model_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "post_update_tuned_model" + ) as post, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_update_tuned_model" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = model_service.UpdateTunedModelRequest.pb( + model_service.UpdateTunedModelRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = gag_tuned_model.TunedModel.to_json(gag_tuned_model.TunedModel()) + req.return_value.content = return_value + + request = model_service.UpdateTunedModelRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = gag_tuned_model.TunedModel() + + client.update_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_tuned_model_rest_bad_request( + request_type=model_service.DeleteTunedModelRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_tuned_model(request) + + +@pytest.mark.parametrize( + "request_type", + [ + model_service.DeleteTunedModelRequest, + dict, + ], +) +def test_delete_tuned_model_rest_call_success(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_tuned_model(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_tuned_model_rest_interceptors(null_interceptor): + transport = transports.ModelServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.ModelServiceRestInterceptor(), + ) + client = ModelServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.ModelServiceRestInterceptor, "pre_delete_tuned_model" + ) as pre: + pre.assert_not_called() + pb_message = model_service.DeleteTunedModelRequest.pb( + model_service.DeleteTunedModelRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = model_service.DeleteTunedModelRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_tuned_model( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_model_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_model), "__call__") as call: + client.get_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.GetModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_models_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_models), "__call__") as call: + client.list_models(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.ListModelsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_tuned_model_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_tuned_model), "__call__") as call: + client.get_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.GetTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_tuned_models_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.list_tuned_models), "__call__" + ) as call: + client.list_tuned_models(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.ListTunedModelsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_tuned_model_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_tuned_model), "__call__" + ) as call: + client.create_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.CreateTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_tuned_model_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_tuned_model), "__call__" + ) as call: + client.update_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.UpdateTunedModelRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_tuned_model_empty_call_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_tuned_model), "__call__" + ) as call: + client.delete_tuned_model(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = model_service.DeleteTunedModelRequest() + + assert args[0] == request_msg + + +def test_model_service_rest_lro_client(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + transport = client.transport + + # Ensure that we have an api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.AbstractOperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.ModelServiceGrpcTransport, + ) + + +def test_model_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.ModelServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_model_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.model_service.transports.ModelServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.ModelServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "get_model", + "list_models", + "get_tuned_model", + "list_tuned_models", + "create_tuned_model", + "update_tuned_model", + "delete_tuned_model", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Additionally, the LRO client (a property) should + # also raise NotImplementedError + with pytest.raises(NotImplementedError): + transport.operations_client + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_model_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_model_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.model_service.transports.ModelServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.ModelServiceTransport() + adc.assert_called_once() + + +def test_model_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + ModelServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + ], +) +def test_model_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.ModelServiceGrpcTransport, + transports.ModelServiceGrpcAsyncIOTransport, + transports.ModelServiceRestTransport, + ], +) +def test_model_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.ModelServiceGrpcTransport, grpc_helpers), + (transports.ModelServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_model_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_model_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.ModelServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_model_service_host_no_port(transport_name): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_model_service_host_with_port(transport_name): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_model_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = ModelServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = ModelServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.get_model._session + session2 = client2.transport.get_model._session + assert session1 != session2 + session1 = client1.transport.list_models._session + session2 = client2.transport.list_models._session + assert session1 != session2 + session1 = client1.transport.get_tuned_model._session + session2 = client2.transport.get_tuned_model._session + assert session1 != session2 + session1 = client1.transport.list_tuned_models._session + session2 = client2.transport.list_tuned_models._session + assert session1 != session2 + session1 = client1.transport.create_tuned_model._session + session2 = client2.transport.create_tuned_model._session + assert session1 != session2 + session1 = client1.transport.update_tuned_model._session + session2 = client2.transport.update_tuned_model._session + assert session1 != session2 + session1 = client1.transport.delete_tuned_model._session + session2 = client2.transport.delete_tuned_model._session + assert session1 != session2 + + +def test_model_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.ModelServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_model_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.ModelServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.ModelServiceGrpcTransport, transports.ModelServiceGrpcAsyncIOTransport], +) +def test_model_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_service_grpc_lro_client(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_model_service_grpc_lro_async_client(): + client = ModelServiceAsyncClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc_asyncio", + ) + transport = client.transport + + # Ensure that we have a api-core operations client. + assert isinstance( + transport.operations_client, + operations_v1.OperationsAsyncClient, + ) + + # Ensure that subsequent calls to the property send the exact same object. + assert transport.operations_client is transport.operations_client + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format( + model=model, + ) + actual = ModelServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = ModelServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_model_path(path) + assert expected == actual + + +def test_tuned_model_path(): + tuned_model = "whelk" + expected = "tunedModels/{tuned_model}".format( + tuned_model=tuned_model, + ) + actual = ModelServiceClient.tuned_model_path(tuned_model) + assert expected == actual + + +def test_parse_tuned_model_path(): + expected = { + "tuned_model": "octopus", + } + path = ModelServiceClient.tuned_model_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_tuned_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "oyster" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = ModelServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = ModelServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "cuttlefish" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = ModelServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = ModelServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = ModelServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = ModelServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + expected = "projects/{project}".format( + project=project, + ) + actual = ModelServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = ModelServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = ModelServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = ModelServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = ModelServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.ModelServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = ModelServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = ModelServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = ModelServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (ModelServiceClient, transports.ModelServiceGrpcTransport), + (ModelServiceAsyncClient, transports.ModelServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_permission_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_permission_service.py new file mode 100644 index 000000000000..6b6a5579e5e7 --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_permission_service.py @@ -0,0 +1,6934 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.permission_service import ( + PermissionServiceAsyncClient, + PermissionServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1alpha.types import permission as gag_permission +from google.ai.generativelanguage_v1alpha.types import permission +from google.ai.generativelanguage_v1alpha.types import permission_service + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PermissionServiceClient._get_default_mtls_endpoint(None) is None + assert ( + PermissionServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PermissionServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test__read_environment_variables(): + assert PermissionServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert PermissionServiceClient._read_environment_variables() == ( + True, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert PermissionServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + PermissionServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert PermissionServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert PermissionServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert PermissionServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + PermissionServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert PermissionServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert PermissionServiceClient._get_client_cert_source(None, False) is None + assert ( + PermissionServiceClient._get_client_cert_source( + mock_provided_cert_source, False + ) + is None + ) + assert ( + PermissionServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + PermissionServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + PermissionServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + PermissionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = PermissionServiceClient._DEFAULT_UNIVERSE + default_endpoint = PermissionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = PermissionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + PermissionServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + PermissionServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == PermissionServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + PermissionServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + PermissionServiceClient._get_api_endpoint( + None, None, default_universe, "always" + ) + == PermissionServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + PermissionServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == PermissionServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + PermissionServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + PermissionServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + PermissionServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + PermissionServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + PermissionServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + PermissionServiceClient._get_universe_domain(None, None) + == PermissionServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + PermissionServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (PermissionServiceClient, "grpc"), + (PermissionServiceAsyncClient, "grpc_asyncio"), + (PermissionServiceClient, "rest"), + ], +) +def test_permission_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PermissionServiceGrpcTransport, "grpc"), + (transports.PermissionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.PermissionServiceRestTransport, "rest"), + ], +) +def test_permission_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (PermissionServiceClient, "grpc"), + (PermissionServiceAsyncClient, "grpc_asyncio"), + (PermissionServiceClient, "rest"), + ], +) +def test_permission_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_permission_service_client_get_transport_class(): + transport = PermissionServiceClient.get_transport_class() + available_transports = [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceRestTransport, + ] + assert transport in available_transports + + transport = PermissionServiceClient.get_transport_class("grpc") + assert transport == transports.PermissionServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + PermissionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceAsyncClient), +) +def test_permission_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(PermissionServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(PermissionServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + "true", + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + "false", + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ( + PermissionServiceClient, + transports.PermissionServiceRestTransport, + "rest", + "true", + ), + ( + PermissionServiceClient, + transports.PermissionServiceRestTransport, + "rest", + "false", + ), + ], +) +@mock.patch.object( + PermissionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_permission_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [PermissionServiceClient, PermissionServiceAsyncClient] +) +@mock.patch.object( + PermissionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PermissionServiceAsyncClient), +) +def test_permission_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [PermissionServiceClient, PermissionServiceAsyncClient] +) +@mock.patch.object( + PermissionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceClient), +) +@mock.patch.object( + PermissionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PermissionServiceAsyncClient), +) +def test_permission_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = PermissionServiceClient._DEFAULT_UNIVERSE + default_endpoint = PermissionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = PermissionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport, "grpc"), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (PermissionServiceClient, transports.PermissionServiceRestTransport, "rest"), + ], +) +def test_permission_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ( + PermissionServiceClient, + transports.PermissionServiceRestTransport, + "rest", + None, + ), + ], +) +def test_permission_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_permission_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.permission_service.transports.PermissionServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = PermissionServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + PermissionServiceClient, + transports.PermissionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_permission_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.CreatePermissionRequest, + dict, + ], +) +def test_create_permission(request_type, transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + response = client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = permission_service.CreatePermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == gag_permission.Permission.Role.OWNER + + +def test_create_permission_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = permission_service.CreatePermissionRequest( + parent="parent_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_permission(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.CreatePermissionRequest( + parent="parent_value", + ) + + +def test_create_permission_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_permission + ] = mock_rpc + request = {} + client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_permission_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_permission + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_permission + ] = mock_rpc + + request = {} + await client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.create_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.CreatePermissionRequest, +): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + ) + response = await client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = permission_service.CreatePermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == gag_permission.Permission.Role.OWNER + + +@pytest.mark.asyncio +async def test_create_permission_async_from_dict(): + await test_create_permission_async(request_type=dict) + + +def test_create_permission_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.CreatePermissionRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + call.return_value = gag_permission.Permission() + client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_permission_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.CreatePermissionRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) + await client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_create_permission_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_permission( + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].permission + mock_val = gag_permission.Permission(name="name_value") + assert arg == mock_val + + +def test_create_permission_flattened_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_permission( + permission_service.CreatePermissionRequest(), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_permission_flattened_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_permission( + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].permission + mock_val = gag_permission.Permission(name="name_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_permission_flattened_error_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_permission( + permission_service.CreatePermissionRequest(), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.GetPermissionRequest, + dict, + ], +) +def test_get_permission(request_type, transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = permission.Permission( + name="name_value", + grantee_type=permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=permission.Permission.Role.OWNER, + ) + response = client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = permission_service.GetPermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == permission.Permission.Role.OWNER + + +def test_get_permission_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = permission_service.GetPermissionRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_permission(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.GetPermissionRequest( + name="name_value", + ) + + +def test_get_permission_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_permission] = mock_rpc + request = {} + client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_permission_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_permission + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_permission + ] = mock_rpc + + request = {} + await client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.GetPermissionRequest, +): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission( + name="name_value", + grantee_type=permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=permission.Permission.Role.OWNER, + ) + ) + response = await client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = permission_service.GetPermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == permission.Permission.Role.OWNER + + +@pytest.mark.asyncio +async def test_get_permission_async_from_dict(): + await test_get_permission_async(request_type=dict) + + +def test_get_permission_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.GetPermissionRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + call.return_value = permission.Permission() + client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_permission_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.GetPermissionRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission() + ) + await client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_permission_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = permission.Permission() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_permission( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_permission_flattened_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_permission( + permission_service.GetPermissionRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_permission_flattened_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = permission.Permission() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_permission( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_permission_flattened_error_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_permission( + permission_service.GetPermissionRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.ListPermissionsRequest, + dict, + ], +) +def test_list_permissions(request_type, transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = permission_service.ListPermissionsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = permission_service.ListPermissionsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListPermissionsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_permissions_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = permission_service.ListPermissionsRequest( + parent="parent_value", + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_permissions(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.ListPermissionsRequest( + parent="parent_value", + page_token="page_token_value", + ) + + +def test_list_permissions_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_permissions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_permissions + ] = mock_rpc + request = {} + client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_permissions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_permissions_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_permissions + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_permissions + ] = mock_rpc + + request = {} + await client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_permissions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_permissions_async( + transport: str = "grpc_asyncio", + request_type=permission_service.ListPermissionsRequest, +): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = permission_service.ListPermissionsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListPermissionsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_permissions_async_from_dict(): + await test_list_permissions_async(request_type=dict) + + +def test_list_permissions_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.ListPermissionsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + call.return_value = permission_service.ListPermissionsResponse() + client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_permissions_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.ListPermissionsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse() + ) + await client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_list_permissions_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = permission_service.ListPermissionsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_permissions( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_permissions_flattened_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_permissions( + permission_service.ListPermissionsRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_permissions_flattened_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = permission_service.ListPermissionsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_permissions( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_permissions_flattened_error_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_permissions( + permission_service.ListPermissionsRequest(), + parent="parent_value", + ) + + +def test_list_permissions_pager(transport_name: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token="abc", + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token="def", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token="ghi", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + expected_metadata = tuple(expected_metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_permissions(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, permission.Permission) for i in results) + + +def test_list_permissions_pages(transport_name: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token="abc", + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token="def", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token="ghi", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + RuntimeError, + ) + pages = list(client.list_permissions(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_permissions_async_pager(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_permissions), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token="abc", + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token="def", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token="ghi", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_permissions( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, permission.Permission) for i in responses) + + +@pytest.mark.asyncio +async def test_list_permissions_async_pages(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_permissions), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token="abc", + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token="def", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token="ghi", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_permissions(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.UpdatePermissionRequest, + dict, + ], +) +def test_update_permission(request_type, transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + response = client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = permission_service.UpdatePermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == gag_permission.Permission.Role.OWNER + + +def test_update_permission_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = permission_service.UpdatePermissionRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_permission(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.UpdatePermissionRequest() + + +def test_update_permission_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_permission + ] = mock_rpc + request = {} + client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_permission_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_permission + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_permission + ] = mock_rpc + + request = {} + await client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.update_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.UpdatePermissionRequest, +): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + ) + response = await client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = permission_service.UpdatePermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == gag_permission.Permission.Role.OWNER + + +@pytest.mark.asyncio +async def test_update_permission_async_from_dict(): + await test_update_permission_async(request_type=dict) + + +def test_update_permission_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.UpdatePermissionRequest() + + request.permission.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + call.return_value = gag_permission.Permission() + client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "permission.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_permission_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.UpdatePermissionRequest() + + request.permission.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) + await client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "permission.name=name_value", + ) in kw["metadata"] + + +def test_update_permission_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_permission( + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].permission + mock_val = gag_permission.Permission(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_permission_flattened_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_permission( + permission_service.UpdatePermissionRequest(), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_permission_flattened_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = gag_permission.Permission() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_permission( + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].permission + mock_val = gag_permission.Permission(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_permission_flattened_error_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_permission( + permission_service.UpdatePermissionRequest(), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.DeletePermissionRequest, + dict, + ], +) +def test_delete_permission(request_type, transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = permission_service.DeletePermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_permission_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = permission_service.DeletePermissionRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_permission(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.DeletePermissionRequest( + name="name_value", + ) + + +def test_delete_permission_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_permission + ] = mock_rpc + request = {} + client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_permission_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_permission + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_permission + ] = mock_rpc + + request = {} + await client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_permission_async( + transport: str = "grpc_asyncio", + request_type=permission_service.DeletePermissionRequest, +): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = permission_service.DeletePermissionRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_permission_async_from_dict(): + await test_delete_permission_async(request_type=dict) + + +def test_delete_permission_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.DeletePermissionRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + call.return_value = None + client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_permission_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.DeletePermissionRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_permission_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_permission( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_permission_flattened_error(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_permission( + permission_service.DeletePermissionRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_permission_flattened_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_permission( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_permission_flattened_error_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_permission( + permission_service.DeletePermissionRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.TransferOwnershipRequest, + dict, + ], +) +def test_transfer_ownership(request_type, transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = permission_service.TransferOwnershipResponse() + response = client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = permission_service.TransferOwnershipRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, permission_service.TransferOwnershipResponse) + + +def test_transfer_ownership_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = permission_service.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.transfer_ownership(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == permission_service.TransferOwnershipRequest( + name="name_value", + email_address="email_address_value", + ) + + +def test_transfer_ownership_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.transfer_ownership in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.transfer_ownership + ] = mock_rpc + request = {} + client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.transfer_ownership(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_transfer_ownership_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.transfer_ownership + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.transfer_ownership + ] = mock_rpc + + request = {} + await client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.transfer_ownership(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_transfer_ownership_async( + transport: str = "grpc_asyncio", + request_type=permission_service.TransferOwnershipRequest, +): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.TransferOwnershipResponse() + ) + response = await client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = permission_service.TransferOwnershipRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, permission_service.TransferOwnershipResponse) + + +@pytest.mark.asyncio +async def test_transfer_ownership_async_from_dict(): + await test_transfer_ownership_async(request_type=dict) + + +def test_transfer_ownership_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.TransferOwnershipRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + call.return_value = permission_service.TransferOwnershipResponse() + client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_transfer_ownership_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = permission_service.TransferOwnershipRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.TransferOwnershipResponse() + ) + await client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_create_permission_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.create_permission + ] = mock_rpc + + request = {} + client.create_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_permission_rest_required_fields( + request_type=permission_service.CreatePermissionRequest, +): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_permission(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_permission_rest_unset_required_fields(): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_permission._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "permission", + ) + ) + ) + + +def test_create_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "tunedModels/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{parent=tunedModels/*}/permissions" % client.transport._host, + args[1], + ) + + +def test_create_permission_rest_flattened_error(transport: str = "rest"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_permission( + permission_service.CreatePermissionRequest(), + parent="parent_value", + permission=gag_permission.Permission(name="name_value"), + ) + + +def test_get_permission_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_permission] = mock_rpc + + request = {} + client.get_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_permission_rest_required_fields( + request_type=permission_service.GetPermissionRequest, +): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = permission.Permission() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_permission(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_permission_rest_unset_required_fields(): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_permission._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = permission.Permission() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "tunedModels/sample1/permissions/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=tunedModels/*/permissions/*}" % client.transport._host, + args[1], + ) + + +def test_get_permission_rest_flattened_error(transport: str = "rest"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_permission( + permission_service.GetPermissionRequest(), + name="name_value", + ) + + +def test_list_permissions_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_permissions in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.list_permissions + ] = mock_rpc + + request = {} + client.list_permissions(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_permissions(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_permissions_rest_required_fields( + request_type=permission_service.ListPermissionsRequest, +): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_permissions._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_permissions._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = permission_service.ListPermissionsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = permission_service.ListPermissionsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_permissions(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_permissions_rest_unset_required_fields(): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_permissions._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +def test_list_permissions_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = permission_service.ListPermissionsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "tunedModels/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = permission_service.ListPermissionsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.list_permissions(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{parent=tunedModels/*}/permissions" % client.transport._host, + args[1], + ) + + +def test_list_permissions_rest_flattened_error(transport: str = "rest"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_permissions( + permission_service.ListPermissionsRequest(), + parent="parent_value", + ) + + +def test_list_permissions_rest_pager(transport: str = "rest"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + permission.Permission(), + ], + next_page_token="abc", + ), + permission_service.ListPermissionsResponse( + permissions=[], + next_page_token="def", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + ], + next_page_token="ghi", + ), + permission_service.ListPermissionsResponse( + permissions=[ + permission.Permission(), + permission.Permission(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + permission_service.ListPermissionsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "tunedModels/sample1"} + + pager = client.list_permissions(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, permission.Permission) for i in results) + + pages = list(client.list_permissions(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_update_permission_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.update_permission + ] = mock_rpc + + request = {} + client.update_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_permission_rest_required_fields( + request_type=permission_service.UpdatePermissionRequest, +): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_permission._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_permission(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_permission_rest_unset_required_fields(): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_permission._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "permission", + "updateMask", + ) + ) + ) + + +def test_update_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission() + + # get arguments that satisfy an http rule for this method + sample_request = { + "permission": {"name": "tunedModels/sample1/permissions/sample2"} + } + + # get truthy value for each flattened field + mock_args = dict( + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{permission.name=tunedModels/*/permissions/*}" + % client.transport._host, + args[1], + ) + + +def test_update_permission_rest_flattened_error(transport: str = "rest"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_permission( + permission_service.UpdatePermissionRequest(), + permission=gag_permission.Permission(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_permission_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_permission in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.delete_permission + ] = mock_rpc + + request = {} + client.delete_permission(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_permission(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_permission_rest_required_fields( + request_type=permission_service.DeletePermissionRequest, +): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_permission._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_permission(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_permission_rest_unset_required_fields(): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_permission._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_delete_permission_rest_flattened(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "tunedModels/sample1/permissions/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_permission(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=tunedModels/*/permissions/*}" % client.transport._host, + args[1], + ) + + +def test_delete_permission_rest_flattened_error(transport: str = "rest"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_permission( + permission_service.DeletePermissionRequest(), + name="name_value", + ) + + +def test_transfer_ownership_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.transfer_ownership in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.transfer_ownership + ] = mock_rpc + + request = {} + client.transfer_ownership(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.transfer_ownership(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_transfer_ownership_rest_required_fields( + request_type=permission_service.TransferOwnershipRequest, +): + transport_class = transports.PermissionServiceRestTransport + + request_init = {} + request_init["name"] = "" + request_init["email_address"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).transfer_ownership._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + jsonified_request["emailAddress"] = "email_address_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).transfer_ownership._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + assert "emailAddress" in jsonified_request + assert jsonified_request["emailAddress"] == "email_address_value" + + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = permission_service.TransferOwnershipResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = permission_service.TransferOwnershipResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.transfer_ownership(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_transfer_ownership_rest_unset_required_fields(): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.transfer_ownership._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "name", + "emailAddress", + ) + ) + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PermissionServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = PermissionServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.PermissionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PermissionServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + transports.PermissionServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = PermissionServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_permission_empty_call_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + call.return_value = gag_permission.Permission() + client.create_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.CreatePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_permission_empty_call_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + call.return_value = permission.Permission() + client.get_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.GetPermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_permissions_empty_call_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + call.return_value = permission_service.ListPermissionsResponse() + client.list_permissions(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.ListPermissionsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_permission_empty_call_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + call.return_value = gag_permission.Permission() + client.update_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.UpdatePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_permission_empty_call_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + call.return_value = None + client.delete_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.DeletePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_transfer_ownership_empty_call_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + call.return_value = permission_service.TransferOwnershipResponse() + client.transfer_ownership(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.TransferOwnershipRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = PermissionServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_permission_empty_call_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + ) + await client.create_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.CreatePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_permission_empty_call_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission.Permission( + name="name_value", + grantee_type=permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=permission.Permission.Role.OWNER, + ) + ) + await client.get_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.GetPermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_permissions_empty_call_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.ListPermissionsResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_permissions(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.ListPermissionsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_permission_empty_call_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + ) + await client.update_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.UpdatePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_permission_empty_call_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.DeletePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_transfer_ownership_empty_call_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + permission_service.TransferOwnershipResponse() + ) + await client.transfer_ownership(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.TransferOwnershipRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = PermissionServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_create_permission_rest_bad_request( + request_type=permission_service.CreatePermissionRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_permission(request) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.CreatePermissionRequest, + dict, + ], +) +def test_create_permission_rest_call_success(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "tunedModels/sample1"} + request_init["permission"] = { + "name": "name_value", + "grantee_type": 1, + "email_address": "email_address_value", + "role": 1, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = permission_service.CreatePermissionRequest.meta.fields["permission"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["permission"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["permission"][field])): + del request_init["permission"][field][i][subfield] + else: + del request_init["permission"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_permission(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == gag_permission.Permission.Role.OWNER + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_permission_rest_interceptors(null_interceptor): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) + client = PermissionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_create_permission" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_create_permission" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = permission_service.CreatePermissionRequest.pb( + permission_service.CreatePermissionRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = gag_permission.Permission.to_json(gag_permission.Permission()) + req.return_value.content = return_value + + request = permission_service.CreatePermissionRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = gag_permission.Permission() + + client.create_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_permission_rest_bad_request( + request_type=permission_service.GetPermissionRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1/permissions/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_permission(request) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.GetPermissionRequest, + dict, + ], +) +def test_get_permission_rest_call_success(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1/permissions/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = permission.Permission( + name="name_value", + grantee_type=permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=permission.Permission.Role.OWNER, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_permission(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == permission.Permission.Role.OWNER + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_permission_rest_interceptors(null_interceptor): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) + client = PermissionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_get_permission" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_get_permission" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = permission_service.GetPermissionRequest.pb( + permission_service.GetPermissionRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = permission.Permission.to_json(permission.Permission()) + req.return_value.content = return_value + + request = permission_service.GetPermissionRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = permission.Permission() + + client.get_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_list_permissions_rest_bad_request( + request_type=permission_service.ListPermissionsRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_permissions(request) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.ListPermissionsRequest, + dict, + ], +) +def test_list_permissions_rest_call_success(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = permission_service.ListPermissionsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = permission_service.ListPermissionsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_permissions(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListPermissionsPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_permissions_rest_interceptors(null_interceptor): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) + client = PermissionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_list_permissions" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_list_permissions" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = permission_service.ListPermissionsRequest.pb( + permission_service.ListPermissionsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = permission_service.ListPermissionsResponse.to_json( + permission_service.ListPermissionsResponse() + ) + req.return_value.content = return_value + + request = permission_service.ListPermissionsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = permission_service.ListPermissionsResponse() + + client.list_permissions( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_permission_rest_bad_request( + request_type=permission_service.UpdatePermissionRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"permission": {"name": "tunedModels/sample1/permissions/sample2"}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.update_permission(request) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.UpdatePermissionRequest, + dict, + ], +) +def test_update_permission_rest_call_success(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"permission": {"name": "tunedModels/sample1/permissions/sample2"}} + request_init["permission"] = { + "name": "tunedModels/sample1/permissions/sample2", + "grantee_type": 1, + "email_address": "email_address_value", + "role": 1, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = permission_service.UpdatePermissionRequest.meta.fields["permission"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["permission"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["permission"][field])): + del request_init["permission"][field][i][subfield] + else: + del request_init["permission"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = gag_permission.Permission( + name="name_value", + grantee_type=gag_permission.Permission.GranteeType.USER, + email_address="email_address_value", + role=gag_permission.Permission.Role.OWNER, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = gag_permission.Permission.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.update_permission(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, gag_permission.Permission) + assert response.name == "name_value" + assert response.grantee_type == gag_permission.Permission.GranteeType.USER + assert response.email_address == "email_address_value" + assert response.role == gag_permission.Permission.Role.OWNER + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_permission_rest_interceptors(null_interceptor): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) + client = PermissionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_update_permission" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_update_permission" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = permission_service.UpdatePermissionRequest.pb( + permission_service.UpdatePermissionRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = gag_permission.Permission.to_json(gag_permission.Permission()) + req.return_value.content = return_value + + request = permission_service.UpdatePermissionRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = gag_permission.Permission() + + client.update_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_permission_rest_bad_request( + request_type=permission_service.DeletePermissionRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1/permissions/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_permission(request) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.DeletePermissionRequest, + dict, + ], +) +def test_delete_permission_rest_call_success(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1/permissions/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_permission(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_permission_rest_interceptors(null_interceptor): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) + client = PermissionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_delete_permission" + ) as pre: + pre.assert_not_called() + pb_message = permission_service.DeletePermissionRequest.pb( + permission_service.DeletePermissionRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = permission_service.DeletePermissionRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_permission( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_transfer_ownership_rest_bad_request( + request_type=permission_service.TransferOwnershipRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.transfer_ownership(request) + + +@pytest.mark.parametrize( + "request_type", + [ + permission_service.TransferOwnershipRequest, + dict, + ], +) +def test_transfer_ownership_rest_call_success(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = permission_service.TransferOwnershipResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = permission_service.TransferOwnershipResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.transfer_ownership(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, permission_service.TransferOwnershipResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_transfer_ownership_rest_interceptors(null_interceptor): + transport = transports.PermissionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PermissionServiceRestInterceptor(), + ) + client = PermissionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PermissionServiceRestInterceptor, "post_transfer_ownership" + ) as post, mock.patch.object( + transports.PermissionServiceRestInterceptor, "pre_transfer_ownership" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = permission_service.TransferOwnershipRequest.pb( + permission_service.TransferOwnershipRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = permission_service.TransferOwnershipResponse.to_json( + permission_service.TransferOwnershipResponse() + ) + req.return_value.content = return_value + + request = permission_service.TransferOwnershipRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = permission_service.TransferOwnershipResponse() + + client.transfer_ownership( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_permission_empty_call_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.create_permission), "__call__" + ) as call: + client.create_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.CreatePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_permission_empty_call_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_permission), "__call__") as call: + client.get_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.GetPermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_permissions_empty_call_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_permissions), "__call__") as call: + client.list_permissions(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.ListPermissionsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_permission_empty_call_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.update_permission), "__call__" + ) as call: + client.update_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.UpdatePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_permission_empty_call_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.delete_permission), "__call__" + ) as call: + client.delete_permission(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.DeletePermissionRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_transfer_ownership_empty_call_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.transfer_ownership), "__call__" + ) as call: + client.transfer_ownership(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = permission_service.TransferOwnershipRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.PermissionServiceGrpcTransport, + ) + + +def test_permission_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.PermissionServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_permission_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.permission_service.transports.PermissionServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.PermissionServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_permission", + "get_permission", + "list_permissions", + "update_permission", + "delete_permission", + "transfer_ownership", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_permission_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.PermissionServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_permission_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.permission_service.transports.PermissionServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.PermissionServiceTransport() + adc.assert_called_once() + + +def test_permission_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + PermissionServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) +def test_permission_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + transports.PermissionServiceRestTransport, + ], +) +def test_permission_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.PermissionServiceGrpcTransport, grpc_helpers), + (transports.PermissionServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_permission_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) +def test_permission_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_permission_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.PermissionServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_permission_service_host_no_port(transport_name): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_permission_service_host_with_port(transport_name): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_permission_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = PermissionServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = PermissionServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_permission._session + session2 = client2.transport.create_permission._session + assert session1 != session2 + session1 = client1.transport.get_permission._session + session2 = client2.transport.get_permission._session + assert session1 != session2 + session1 = client1.transport.list_permissions._session + session2 = client2.transport.list_permissions._session + assert session1 != session2 + session1 = client1.transport.update_permission._session + session2 = client2.transport.update_permission._session + assert session1 != session2 + session1 = client1.transport.delete_permission._session + session2 = client2.transport.delete_permission._session + assert session1 != session2 + session1 = client1.transport.transfer_ownership._session + session2 = client2.transport.transfer_ownership._session + assert session1 != session2 + + +def test_permission_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.PermissionServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_permission_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.PermissionServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) +def test_permission_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.PermissionServiceGrpcTransport, + transports.PermissionServiceGrpcAsyncIOTransport, + ], +) +def test_permission_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_permission_path(): + tuned_model = "squid" + permission = "clam" + expected = "tunedModels/{tuned_model}/permissions/{permission}".format( + tuned_model=tuned_model, + permission=permission, + ) + actual = PermissionServiceClient.permission_path(tuned_model, permission) + assert expected == actual + + +def test_parse_permission_path(): + expected = { + "tuned_model": "whelk", + "permission": "octopus", + } + path = PermissionServiceClient.permission_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_permission_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "oyster" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = PermissionServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "nudibranch", + } + path = PermissionServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "cuttlefish" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = PermissionServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "mussel", + } + path = PermissionServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "winkle" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = PermissionServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nautilus", + } + path = PermissionServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "scallop" + expected = "projects/{project}".format( + project=project, + ) + actual = PermissionServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "abalone", + } + path = PermissionServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "squid" + location = "clam" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = PermissionServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "whelk", + "location": "octopus", + } + path = PermissionServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PermissionServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.PermissionServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.PermissionServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = PermissionServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = PermissionServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = PermissionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PermissionServiceClient, transports.PermissionServiceGrpcTransport), + ( + PermissionServiceAsyncClient, + transports.PermissionServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_prediction_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_prediction_service.py new file mode 100644 index 000000000000..452a4da5caf2 --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_prediction_service.py @@ -0,0 +1,3001 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import struct_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.prediction_service import ( + PredictionServiceAsyncClient, + PredictionServiceClient, + transports, +) +from google.ai.generativelanguage_v1alpha.types import prediction_service + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert PredictionServiceClient._get_default_mtls_endpoint(None) is None + assert ( + PredictionServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + PredictionServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test__read_environment_variables(): + assert PredictionServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert PredictionServiceClient._read_environment_variables() == ( + True, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert PredictionServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + PredictionServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert PredictionServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert PredictionServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert PredictionServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + PredictionServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert PredictionServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert PredictionServiceClient._get_client_cert_source(None, False) is None + assert ( + PredictionServiceClient._get_client_cert_source( + mock_provided_cert_source, False + ) + is None + ) + assert ( + PredictionServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + PredictionServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + PredictionServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + PredictionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = PredictionServiceClient._DEFAULT_UNIVERSE + default_endpoint = PredictionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = PredictionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + PredictionServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + PredictionServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + PredictionServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + PredictionServiceClient._get_api_endpoint( + None, None, default_universe, "always" + ) + == PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + PredictionServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == PredictionServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + PredictionServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + PredictionServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + PredictionServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + PredictionServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + PredictionServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + PredictionServiceClient._get_universe_domain(None, None) + == PredictionServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + PredictionServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (PredictionServiceClient, "grpc"), + (PredictionServiceAsyncClient, "grpc_asyncio"), + (PredictionServiceClient, "rest"), + ], +) +def test_prediction_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.PredictionServiceGrpcTransport, "grpc"), + (transports.PredictionServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.PredictionServiceRestTransport, "rest"), + ], +) +def test_prediction_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (PredictionServiceClient, "grpc"), + (PredictionServiceAsyncClient, "grpc_asyncio"), + (PredictionServiceClient, "rest"), + ], +) +def test_prediction_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_prediction_service_client_get_transport_class(): + transport = PredictionServiceClient.get_transport_class() + available_transports = [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceRestTransport, + ] + assert transport in available_transports + + transport = PredictionServiceClient.get_transport_class("grpc") + assert transport == transports.PredictionServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (PredictionServiceClient, transports.PredictionServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + PredictionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceAsyncClient), +) +def test_prediction_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(PredictionServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + "true", + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + "false", + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ( + PredictionServiceClient, + transports.PredictionServiceRestTransport, + "rest", + "true", + ), + ( + PredictionServiceClient, + transports.PredictionServiceRestTransport, + "rest", + "false", + ), + ], +) +@mock.patch.object( + PredictionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_prediction_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient] +) +@mock.patch.object( + PredictionServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(PredictionServiceAsyncClient), +) +def test_prediction_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [PredictionServiceClient, PredictionServiceAsyncClient] +) +@mock.patch.object( + PredictionServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceClient), +) +@mock.patch.object( + PredictionServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(PredictionServiceAsyncClient), +) +def test_prediction_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = PredictionServiceClient._DEFAULT_UNIVERSE + default_endpoint = PredictionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = PredictionServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport, "grpc"), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (PredictionServiceClient, transports.PredictionServiceRestTransport, "rest"), + ], +) +def test_prediction_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ( + PredictionServiceClient, + transports.PredictionServiceRestTransport, + "rest", + None, + ), + ], +) +def test_prediction_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_prediction_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.prediction_service.transports.PredictionServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = PredictionServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + PredictionServiceClient, + transports.PredictionServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_prediction_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + prediction_service.PredictRequest, + dict, + ], +) +def test_predict(request_type, transport: str = "grpc"): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + response = client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = prediction_service.PredictRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.PredictResponse) + + +def test_predict_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = prediction_service.PredictRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.predict(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == prediction_service.PredictRequest( + model="model_value", + ) + + +def test_predict_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.predict] = mock_rpc + request = {} + client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_predict_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.predict + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.predict + ] = mock_rpc + + request = {} + await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_predict_async( + transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest +): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) + response = await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = prediction_service.PredictRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.PredictResponse) + + +@pytest.mark.asyncio +async def test_predict_async_from_dict(): + await test_predict_async(request_type=dict) + + +def test_predict_field_headers(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.PredictRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value = prediction_service.PredictResponse() + client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_predict_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = prediction_service.PredictRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) + await client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_predict_flattened(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.predict( + model="model_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].instances + mock_val = [struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)] + assert arg == mock_val + + +def test_predict_flattened_error(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.predict( + prediction_service.PredictRequest(), + model="model_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + ) + + +@pytest.mark.asyncio +async def test_predict_flattened_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = prediction_service.PredictResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.predict( + model="model_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].instances + mock_val = [struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_predict_flattened_error_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.predict( + prediction_service.PredictRequest(), + model="model_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + ) + + +def test_predict_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.predict in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.predict] = mock_rpc + + request = {} + client.predict(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.predict(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_predict_rest_required_fields(request_type=prediction_service.PredictRequest): + transport_class = transports.PredictionServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).predict._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).predict._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = prediction_service.PredictResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = prediction_service.PredictResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.predict(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_predict_rest_unset_required_fields(): + transport = transports.PredictionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.predict._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "instances", + ) + ) + ) + + +def test_predict_rest_flattened(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = prediction_service.PredictResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = prediction_service.PredictResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.predict(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:predict" % client.transport._host, args[1] + ) + + +def test_predict_rest_flattened_error(transport: str = "rest"): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.predict( + prediction_service.PredictRequest(), + model="model_value", + instances=[struct_pb2.Value(null_value=struct_pb2.NullValue.NULL_VALUE)], + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = PredictionServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = PredictionServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.PredictionServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.PredictionServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + transports.PredictionServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = PredictionServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_predict_empty_call_grpc(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + call.return_value = prediction_service.PredictResponse() + client.predict(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = prediction_service.PredictRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = PredictionServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_predict_empty_call_grpc_asyncio(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + prediction_service.PredictResponse() + ) + await client.predict(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = prediction_service.PredictRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = PredictionServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_predict_rest_bad_request(request_type=prediction_service.PredictRequest): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.predict(request) + + +@pytest.mark.parametrize( + "request_type", + [ + prediction_service.PredictRequest, + dict, + ], +) +def test_predict_rest_call_success(request_type): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = prediction_service.PredictResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = prediction_service.PredictResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.predict(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, prediction_service.PredictResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_predict_rest_interceptors(null_interceptor): + transport = transports.PredictionServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.PredictionServiceRestInterceptor(), + ) + client = PredictionServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.PredictionServiceRestInterceptor, "post_predict" + ) as post, mock.patch.object( + transports.PredictionServiceRestInterceptor, "pre_predict" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = prediction_service.PredictRequest.pb( + prediction_service.PredictRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = prediction_service.PredictResponse.to_json( + prediction_service.PredictResponse() + ) + req.return_value.content = return_value + + request = prediction_service.PredictRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = prediction_service.PredictResponse() + + client.predict( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_predict_empty_call_rest(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.predict), "__call__") as call: + client.predict(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = prediction_service.PredictRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.PredictionServiceGrpcTransport, + ) + + +def test_prediction_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.PredictionServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_prediction_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.prediction_service.transports.PredictionServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.PredictionServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "predict", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_prediction_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.PredictionServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_prediction_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.prediction_service.transports.PredictionServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.PredictionServiceTransport() + adc.assert_called_once() + + +def test_prediction_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + PredictionServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) +def test_prediction_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + transports.PredictionServiceRestTransport, + ], +) +def test_prediction_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.PredictionServiceGrpcTransport, grpc_helpers), + (transports.PredictionServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_prediction_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) +def test_prediction_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_prediction_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.PredictionServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_prediction_service_host_no_port(transport_name): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_prediction_service_host_with_port(transport_name): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_prediction_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = PredictionServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = PredictionServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.predict._session + session2 = client2.transport.predict._session + assert session1 != session2 + + +def test_prediction_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.PredictionServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_prediction_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.PredictionServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) +def test_prediction_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.PredictionServiceGrpcTransport, + transports.PredictionServiceGrpcAsyncIOTransport, + ], +) +def test_prediction_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format( + model=model, + ) + actual = PredictionServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = PredictionServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = PredictionServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = PredictionServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = PredictionServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = PredictionServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = PredictionServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = PredictionServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format( + project=project, + ) + actual = PredictionServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = PredictionServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = PredictionServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = PredictionServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = PredictionServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.PredictionServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.PredictionServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = PredictionServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = PredictionServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = PredictionServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (PredictionServiceClient, transports.PredictionServiceGrpcTransport), + ( + PredictionServiceAsyncClient, + transports.PredictionServiceGrpcAsyncIOTransport, + ), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_retriever_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_retriever_service.py new file mode 100644 index 000000000000..2a80e776b391 --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_retriever_service.py @@ -0,0 +1,16486 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account +from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import timestamp_pb2 # type: ignore + +from google.ai.generativelanguage_v1alpha.services.retriever_service import ( + RetrieverServiceAsyncClient, + RetrieverServiceClient, + pagers, + transports, +) +from google.ai.generativelanguage_v1alpha.types import retriever, retriever_service + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert RetrieverServiceClient._get_default_mtls_endpoint(None) is None + assert ( + RetrieverServiceClient._get_default_mtls_endpoint(api_endpoint) + == api_mtls_endpoint + ) + assert ( + RetrieverServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + RetrieverServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + RetrieverServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + RetrieverServiceClient._get_default_mtls_endpoint(non_googleapi) + == non_googleapi + ) + + +def test__read_environment_variables(): + assert RetrieverServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert RetrieverServiceClient._read_environment_variables() == ( + True, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert RetrieverServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + RetrieverServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert RetrieverServiceClient._read_environment_variables() == ( + False, + "never", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert RetrieverServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert RetrieverServiceClient._read_environment_variables() == ( + False, + "auto", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + RetrieverServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert RetrieverServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert RetrieverServiceClient._get_client_cert_source(None, False) is None + assert ( + RetrieverServiceClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + RetrieverServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + RetrieverServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + RetrieverServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + RetrieverServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceClient), +) +@mock.patch.object( + RetrieverServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = RetrieverServiceClient._DEFAULT_UNIVERSE + default_endpoint = RetrieverServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = RetrieverServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + RetrieverServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + RetrieverServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == RetrieverServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + RetrieverServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + RetrieverServiceClient._get_api_endpoint(None, None, default_universe, "always") + == RetrieverServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + RetrieverServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == RetrieverServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + RetrieverServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + RetrieverServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + RetrieverServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + RetrieverServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + RetrieverServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + RetrieverServiceClient._get_universe_domain(None, None) + == RetrieverServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + RetrieverServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (RetrieverServiceClient, "grpc"), + (RetrieverServiceAsyncClient, "grpc_asyncio"), + (RetrieverServiceClient, "rest"), + ], +) +def test_retriever_service_client_from_service_account_info( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.RetrieverServiceGrpcTransport, "grpc"), + (transports.RetrieverServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.RetrieverServiceRestTransport, "rest"), + ], +) +def test_retriever_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (RetrieverServiceClient, "grpc"), + (RetrieverServiceAsyncClient, "grpc_asyncio"), + (RetrieverServiceClient, "rest"), + ], +) +def test_retriever_service_client_from_service_account_file( + client_class, transport_name +): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_retriever_service_client_get_transport_class(): + transport = RetrieverServiceClient.get_transport_class() + available_transports = [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceRestTransport, + ] + assert transport in available_transports + + transport = RetrieverServiceClient.get_transport_class("grpc") + assert transport == transports.RetrieverServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (RetrieverServiceClient, transports.RetrieverServiceGrpcTransport, "grpc"), + ( + RetrieverServiceAsyncClient, + transports.RetrieverServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (RetrieverServiceClient, transports.RetrieverServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + RetrieverServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceClient), +) +@mock.patch.object( + RetrieverServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceAsyncClient), +) +def test_retriever_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(RetrieverServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(RetrieverServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + ( + RetrieverServiceClient, + transports.RetrieverServiceGrpcTransport, + "grpc", + "true", + ), + ( + RetrieverServiceAsyncClient, + transports.RetrieverServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + ( + RetrieverServiceClient, + transports.RetrieverServiceGrpcTransport, + "grpc", + "false", + ), + ( + RetrieverServiceAsyncClient, + transports.RetrieverServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + ( + RetrieverServiceClient, + transports.RetrieverServiceRestTransport, + "rest", + "true", + ), + ( + RetrieverServiceClient, + transports.RetrieverServiceRestTransport, + "rest", + "false", + ), + ], +) +@mock.patch.object( + RetrieverServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceClient), +) +@mock.patch.object( + RetrieverServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_retriever_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class", [RetrieverServiceClient, RetrieverServiceAsyncClient] +) +@mock.patch.object( + RetrieverServiceClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(RetrieverServiceClient), +) +@mock.patch.object( + RetrieverServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(RetrieverServiceAsyncClient), +) +def test_retriever_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize( + "client_class", [RetrieverServiceClient, RetrieverServiceAsyncClient] +) +@mock.patch.object( + RetrieverServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceClient), +) +@mock.patch.object( + RetrieverServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(RetrieverServiceAsyncClient), +) +def test_retriever_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = RetrieverServiceClient._DEFAULT_UNIVERSE + default_endpoint = RetrieverServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = RetrieverServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (RetrieverServiceClient, transports.RetrieverServiceGrpcTransport, "grpc"), + ( + RetrieverServiceAsyncClient, + transports.RetrieverServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (RetrieverServiceClient, transports.RetrieverServiceRestTransport, "rest"), + ], +) +def test_retriever_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + RetrieverServiceClient, + transports.RetrieverServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + RetrieverServiceAsyncClient, + transports.RetrieverServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ( + RetrieverServiceClient, + transports.RetrieverServiceRestTransport, + "rest", + None, + ), + ], +) +def test_retriever_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_retriever_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.retriever_service.transports.RetrieverServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = RetrieverServiceClient( + client_options={"api_endpoint": "squid.clam.whelk"} + ) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + ( + RetrieverServiceClient, + transports.RetrieverServiceGrpcTransport, + "grpc", + grpc_helpers, + ), + ( + RetrieverServiceAsyncClient, + transports.RetrieverServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_retriever_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.CreateCorpusRequest, + dict, + ], +) +def test_create_corpus(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + response = client.create_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.CreateCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +def test_create_corpus_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.CreateCorpusRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_corpus(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.CreateCorpusRequest() + + +def test_create_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_corpus] = mock_rpc + request = {} + client.create_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_corpus + ] = mock_rpc + + request = {} + await client.create_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.create_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_corpus_async( + transport: str = "grpc_asyncio", request_type=retriever_service.CreateCorpusRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + ) + response = await client.create_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.CreateCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.asyncio +async def test_create_corpus_async_from_dict(): + await test_create_corpus_async(request_type=dict) + + +def test_create_corpus_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_corpus( + corpus=retriever.Corpus(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].corpus + mock_val = retriever.Corpus(name="name_value") + assert arg == mock_val + + +def test_create_corpus_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_corpus( + retriever_service.CreateCorpusRequest(), + corpus=retriever.Corpus(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_corpus_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Corpus()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_corpus( + corpus=retriever.Corpus(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].corpus + mock_val = retriever.Corpus(name="name_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_corpus_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_corpus( + retriever_service.CreateCorpusRequest(), + corpus=retriever.Corpus(name="name_value"), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.GetCorpusRequest, + dict, + ], +) +def test_get_corpus(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + response = client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.GetCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +def test_get_corpus_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.GetCorpusRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_corpus(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.GetCorpusRequest( + name="name_value", + ) + + +def test_get_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_corpus] = mock_rpc + request = {} + client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_corpus_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_corpus + ] = mock_rpc + + request = {} + await client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_corpus_async( + transport: str = "grpc_asyncio", request_type=retriever_service.GetCorpusRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + ) + response = await client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.GetCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.asyncio +async def test_get_corpus_async_from_dict(): + await test_get_corpus_async(request_type=dict) + + +def test_get_corpus_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.GetCorpusRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + call.return_value = retriever.Corpus() + client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_corpus_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.GetCorpusRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Corpus()) + await client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_corpus_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_corpus( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_corpus_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_corpus( + retriever_service.GetCorpusRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_corpus_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Corpus()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_corpus( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_corpus_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_corpus( + retriever_service.GetCorpusRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.UpdateCorpusRequest, + dict, + ], +) +def test_update_corpus(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + response = client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.UpdateCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +def test_update_corpus_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.UpdateCorpusRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_corpus(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.UpdateCorpusRequest() + + +def test_update_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_corpus] = mock_rpc + request = {} + client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_corpus + ] = mock_rpc + + request = {} + await client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.update_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_corpus_async( + transport: str = "grpc_asyncio", request_type=retriever_service.UpdateCorpusRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + ) + response = await client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.UpdateCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.asyncio +async def test_update_corpus_async_from_dict(): + await test_update_corpus_async(request_type=dict) + + +def test_update_corpus_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.UpdateCorpusRequest() + + request.corpus.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + call.return_value = retriever.Corpus() + client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "corpus.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_corpus_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.UpdateCorpusRequest() + + request.corpus.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Corpus()) + await client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "corpus.name=name_value", + ) in kw["metadata"] + + +def test_update_corpus_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_corpus( + corpus=retriever.Corpus(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].corpus + mock_val = retriever.Corpus(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_corpus_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_corpus( + retriever_service.UpdateCorpusRequest(), + corpus=retriever.Corpus(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_corpus_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Corpus() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Corpus()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_corpus( + corpus=retriever.Corpus(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].corpus + mock_val = retriever.Corpus(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_corpus_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_corpus( + retriever_service.UpdateCorpusRequest(), + corpus=retriever.Corpus(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.DeleteCorpusRequest, + dict, + ], +) +def test_delete_corpus(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.DeleteCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_corpus_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.DeleteCorpusRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_corpus(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.DeleteCorpusRequest( + name="name_value", + ) + + +def test_delete_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_corpus] = mock_rpc + request = {} + client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_corpus + ] = mock_rpc + + request = {} + await client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_corpus_async( + transport: str = "grpc_asyncio", request_type=retriever_service.DeleteCorpusRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.DeleteCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_corpus_async_from_dict(): + await test_delete_corpus_async(request_type=dict) + + +def test_delete_corpus_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.DeleteCorpusRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + call.return_value = None + client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_corpus_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.DeleteCorpusRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_corpus_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_corpus( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_corpus_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_corpus( + retriever_service.DeleteCorpusRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_corpus_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_corpus( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_corpus_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_corpus( + retriever_service.DeleteCorpusRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.ListCorporaRequest, + dict, + ], +) +def test_list_corpora(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListCorporaResponse( + next_page_token="next_page_token_value", + ) + response = client.list_corpora(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.ListCorporaRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCorporaPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_corpora_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.ListCorporaRequest( + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_corpora(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.ListCorporaRequest( + page_token="page_token_value", + ) + + +def test_list_corpora_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_corpora in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_corpora] = mock_rpc + request = {} + client.list_corpora(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_corpora(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_corpora_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_corpora + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_corpora + ] = mock_rpc + + request = {} + await client.list_corpora(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_corpora(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_corpora_async( + transport: str = "grpc_asyncio", request_type=retriever_service.ListCorporaRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListCorporaResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_corpora(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.ListCorporaRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCorporaAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_corpora_async_from_dict(): + await test_list_corpora_async(request_type=dict) + + +def test_list_corpora_pager(transport_name: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + retriever.Corpus(), + ], + next_page_token="abc", + ), + retriever_service.ListCorporaResponse( + corpora=[], + next_page_token="def", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + ], + next_page_token="ghi", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + pager = client.list_corpora(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, retriever.Corpus) for i in results) + + +def test_list_corpora_pages(transport_name: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + retriever.Corpus(), + ], + next_page_token="abc", + ), + retriever_service.ListCorporaResponse( + corpora=[], + next_page_token="def", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + ], + next_page_token="ghi", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + ], + ), + RuntimeError, + ) + pages = list(client.list_corpora(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_corpora_async_pager(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_corpora), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + retriever.Corpus(), + ], + next_page_token="abc", + ), + retriever_service.ListCorporaResponse( + corpora=[], + next_page_token="def", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + ], + next_page_token="ghi", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_corpora( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, retriever.Corpus) for i in responses) + + +@pytest.mark.asyncio +async def test_list_corpora_async_pages(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_corpora), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + retriever.Corpus(), + ], + next_page_token="abc", + ), + retriever_service.ListCorporaResponse( + corpora=[], + next_page_token="def", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + ], + next_page_token="ghi", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_corpora(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.QueryCorpusRequest, + dict, + ], +) +def test_query_corpus(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.QueryCorpusResponse() + response = client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.QueryCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.QueryCorpusResponse) + + +def test_query_corpus_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.QueryCorpusRequest( + name="name_value", + query="query_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.query_corpus(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.QueryCorpusRequest( + name="name_value", + query="query_value", + ) + + +def test_query_corpus_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.query_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.query_corpus] = mock_rpc + request = {} + client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.query_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_query_corpus_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_corpus + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.query_corpus + ] = mock_rpc + + request = {} + await client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.query_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_query_corpus_async( + transport: str = "grpc_asyncio", request_type=retriever_service.QueryCorpusRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.QueryCorpusResponse() + ) + response = await client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.QueryCorpusRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.QueryCorpusResponse) + + +@pytest.mark.asyncio +async def test_query_corpus_async_from_dict(): + await test_query_corpus_async(request_type=dict) + + +def test_query_corpus_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.QueryCorpusRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + call.return_value = retriever_service.QueryCorpusResponse() + client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_query_corpus_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.QueryCorpusRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.QueryCorpusResponse() + ) + await client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.CreateDocumentRequest, + dict, + ], +) +def test_create_document(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document( + name="name_value", + display_name="display_name_value", + ) + response = client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.CreateDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +def test_create_document_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.CreateDocumentRequest( + parent="parent_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_document(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.CreateDocumentRequest( + parent="parent_value", + ) + + +def test_create_document_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_document] = mock_rpc + request = {} + client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_document_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_document + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_document + ] = mock_rpc + + request = {} + await client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.create_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_document_async( + transport: str = "grpc_asyncio", + request_type=retriever_service.CreateDocumentRequest, +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Document( + name="name_value", + display_name="display_name_value", + ) + ) + response = await client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.CreateDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.asyncio +async def test_create_document_async_from_dict(): + await test_create_document_async(request_type=dict) + + +def test_create_document_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.CreateDocumentRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + call.return_value = retriever.Document() + client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_document_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.CreateDocumentRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Document()) + await client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_create_document_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_document( + parent="parent_value", + document=retriever.Document(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].document + mock_val = retriever.Document(name="name_value") + assert arg == mock_val + + +def test_create_document_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_document( + retriever_service.CreateDocumentRequest(), + parent="parent_value", + document=retriever.Document(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_document_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Document()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_document( + parent="parent_value", + document=retriever.Document(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].document + mock_val = retriever.Document(name="name_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_document_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_document( + retriever_service.CreateDocumentRequest(), + parent="parent_value", + document=retriever.Document(name="name_value"), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.GetDocumentRequest, + dict, + ], +) +def test_get_document(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document( + name="name_value", + display_name="display_name_value", + ) + response = client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.GetDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +def test_get_document_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.GetDocumentRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_document(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.GetDocumentRequest( + name="name_value", + ) + + +def test_get_document_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_document] = mock_rpc + request = {} + client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_document_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_document + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_document + ] = mock_rpc + + request = {} + await client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_document_async( + transport: str = "grpc_asyncio", request_type=retriever_service.GetDocumentRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Document( + name="name_value", + display_name="display_name_value", + ) + ) + response = await client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.GetDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.asyncio +async def test_get_document_async_from_dict(): + await test_get_document_async(request_type=dict) + + +def test_get_document_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.GetDocumentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + call.return_value = retriever.Document() + client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_document_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.GetDocumentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Document()) + await client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_document_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_document( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_document_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_document( + retriever_service.GetDocumentRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_document_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Document()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_document( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_document_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_document( + retriever_service.GetDocumentRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.UpdateDocumentRequest, + dict, + ], +) +def test_update_document(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document( + name="name_value", + display_name="display_name_value", + ) + response = client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.UpdateDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +def test_update_document_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.UpdateDocumentRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_document(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.UpdateDocumentRequest() + + +def test_update_document_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_document] = mock_rpc + request = {} + client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_document_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_document + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_document + ] = mock_rpc + + request = {} + await client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.update_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_document_async( + transport: str = "grpc_asyncio", + request_type=retriever_service.UpdateDocumentRequest, +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Document( + name="name_value", + display_name="display_name_value", + ) + ) + response = await client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.UpdateDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.asyncio +async def test_update_document_async_from_dict(): + await test_update_document_async(request_type=dict) + + +def test_update_document_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.UpdateDocumentRequest() + + request.document.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + call.return_value = retriever.Document() + client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "document.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_document_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.UpdateDocumentRequest() + + request.document.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Document()) + await client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "document.name=name_value", + ) in kw["metadata"] + + +def test_update_document_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_document( + document=retriever.Document(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].document + mock_val = retriever.Document(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_document_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_document( + retriever_service.UpdateDocumentRequest(), + document=retriever.Document(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_document_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Document() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Document()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_document( + document=retriever.Document(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].document + mock_val = retriever.Document(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_document_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_document( + retriever_service.UpdateDocumentRequest(), + document=retriever.Document(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.DeleteDocumentRequest, + dict, + ], +) +def test_delete_document(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.DeleteDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_document_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.DeleteDocumentRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_document(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.DeleteDocumentRequest( + name="name_value", + ) + + +def test_delete_document_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_document] = mock_rpc + request = {} + client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_document_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_document + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_document + ] = mock_rpc + + request = {} + await client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_document_async( + transport: str = "grpc_asyncio", + request_type=retriever_service.DeleteDocumentRequest, +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.DeleteDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_document_async_from_dict(): + await test_delete_document_async(request_type=dict) + + +def test_delete_document_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.DeleteDocumentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + call.return_value = None + client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_document_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.DeleteDocumentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_document_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_document( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_document_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_document( + retriever_service.DeleteDocumentRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_document_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_document( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_document_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_document( + retriever_service.DeleteDocumentRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.ListDocumentsRequest, + dict, + ], +) +def test_list_documents(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListDocumentsResponse( + next_page_token="next_page_token_value", + ) + response = client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.ListDocumentsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDocumentsPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_documents_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.ListDocumentsRequest( + parent="parent_value", + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_documents(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.ListDocumentsRequest( + parent="parent_value", + page_token="page_token_value", + ) + + +def test_list_documents_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_documents in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_documents] = mock_rpc + request = {} + client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_documents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_documents_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_documents + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_documents + ] = mock_rpc + + request = {} + await client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_documents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_documents_async( + transport: str = "grpc_asyncio", request_type=retriever_service.ListDocumentsRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListDocumentsResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.ListDocumentsRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDocumentsAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_documents_async_from_dict(): + await test_list_documents_async(request_type=dict) + + +def test_list_documents_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.ListDocumentsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + call.return_value = retriever_service.ListDocumentsResponse() + client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_documents_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.ListDocumentsRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListDocumentsResponse() + ) + await client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_list_documents_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListDocumentsResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_documents( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_documents_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_documents( + retriever_service.ListDocumentsRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_documents_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListDocumentsResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListDocumentsResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_documents( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_documents_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_documents( + retriever_service.ListDocumentsRequest(), + parent="parent_value", + ) + + +def test_list_documents_pager(transport_name: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + retriever.Document(), + ], + next_page_token="abc", + ), + retriever_service.ListDocumentsResponse( + documents=[], + next_page_token="def", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + ], + next_page_token="ghi", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + expected_metadata = tuple(expected_metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_documents(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, retriever.Document) for i in results) + + +def test_list_documents_pages(transport_name: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + retriever.Document(), + ], + next_page_token="abc", + ), + retriever_service.ListDocumentsResponse( + documents=[], + next_page_token="def", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + ], + next_page_token="ghi", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + ], + ), + RuntimeError, + ) + pages = list(client.list_documents(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_documents_async_pager(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_documents), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + retriever.Document(), + ], + next_page_token="abc", + ), + retriever_service.ListDocumentsResponse( + documents=[], + next_page_token="def", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + ], + next_page_token="ghi", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_documents( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, retriever.Document) for i in responses) + + +@pytest.mark.asyncio +async def test_list_documents_async_pages(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_documents), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + retriever.Document(), + ], + next_page_token="abc", + ), + retriever_service.ListDocumentsResponse( + documents=[], + next_page_token="def", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + ], + next_page_token="ghi", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_documents(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.QueryDocumentRequest, + dict, + ], +) +def test_query_document(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.QueryDocumentResponse() + response = client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.QueryDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.QueryDocumentResponse) + + +def test_query_document_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.QueryDocumentRequest( + name="name_value", + query="query_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.query_document(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.QueryDocumentRequest( + name="name_value", + query="query_value", + ) + + +def test_query_document_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.query_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.query_document] = mock_rpc + request = {} + client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.query_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_query_document_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.query_document + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.query_document + ] = mock_rpc + + request = {} + await client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.query_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_query_document_async( + transport: str = "grpc_asyncio", request_type=retriever_service.QueryDocumentRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.QueryDocumentResponse() + ) + response = await client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.QueryDocumentRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.QueryDocumentResponse) + + +@pytest.mark.asyncio +async def test_query_document_async_from_dict(): + await test_query_document_async(request_type=dict) + + +def test_query_document_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.QueryDocumentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + call.return_value = retriever_service.QueryDocumentResponse() + client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_query_document_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.QueryDocumentRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.QueryDocumentResponse() + ) + await client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.CreateChunkRequest, + dict, + ], +) +def test_create_chunk(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + response = client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.CreateChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +def test_create_chunk_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.CreateChunkRequest( + parent="parent_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.create_chunk(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.CreateChunkRequest( + parent="parent_value", + ) + + +def test_create_chunk_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_chunk] = mock_rpc + request = {} + client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_chunk_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.create_chunk + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.create_chunk + ] = mock_rpc + + request = {} + await client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.create_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_create_chunk_async( + transport: str = "grpc_asyncio", request_type=retriever_service.CreateChunkRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + ) + response = await client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.CreateChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +@pytest.mark.asyncio +async def test_create_chunk_async_from_dict(): + await test_create_chunk_async(request_type=dict) + + +def test_create_chunk_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.CreateChunkRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + call.return_value = retriever.Chunk() + client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_create_chunk_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.CreateChunkRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Chunk()) + await client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_create_chunk_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.create_chunk( + parent="parent_value", + chunk=retriever.Chunk(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].chunk + mock_val = retriever.Chunk(name="name_value") + assert arg == mock_val + + +def test_create_chunk_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_chunk( + retriever_service.CreateChunkRequest(), + parent="parent_value", + chunk=retriever.Chunk(name="name_value"), + ) + + +@pytest.mark.asyncio +async def test_create_chunk_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Chunk()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.create_chunk( + parent="parent_value", + chunk=retriever.Chunk(name="name_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + arg = args[0].chunk + mock_val = retriever.Chunk(name="name_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_create_chunk_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.create_chunk( + retriever_service.CreateChunkRequest(), + parent="parent_value", + chunk=retriever.Chunk(name="name_value"), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.BatchCreateChunksRequest, + dict, + ], +) +def test_batch_create_chunks(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.BatchCreateChunksResponse() + response = client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.BatchCreateChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.BatchCreateChunksResponse) + + +def test_batch_create_chunks_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.BatchCreateChunksRequest( + parent="parent_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.batch_create_chunks(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.BatchCreateChunksRequest( + parent="parent_value", + ) + + +def test_batch_create_chunks_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_chunks in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_chunks + ] = mock_rpc + request = {} + client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_create_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_create_chunks_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_create_chunks + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_create_chunks + ] = mock_rpc + + request = {} + await client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.batch_create_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_create_chunks_async( + transport: str = "grpc_asyncio", + request_type=retriever_service.BatchCreateChunksRequest, +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.BatchCreateChunksResponse() + ) + response = await client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.BatchCreateChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.BatchCreateChunksResponse) + + +@pytest.mark.asyncio +async def test_batch_create_chunks_async_from_dict(): + await test_batch_create_chunks_async(request_type=dict) + + +def test_batch_create_chunks_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.BatchCreateChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + call.return_value = retriever_service.BatchCreateChunksResponse() + client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_create_chunks_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.BatchCreateChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.BatchCreateChunksResponse() + ) + await client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.GetChunkRequest, + dict, + ], +) +def test_get_chunk(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + response = client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.GetChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +def test_get_chunk_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.GetChunkRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.get_chunk(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.GetChunkRequest( + name="name_value", + ) + + +def test_get_chunk_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_chunk] = mock_rpc + request = {} + client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_chunk_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.get_chunk + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.get_chunk + ] = mock_rpc + + request = {} + await client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.get_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_chunk_async( + transport: str = "grpc_asyncio", request_type=retriever_service.GetChunkRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + ) + response = await client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.GetChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +@pytest.mark.asyncio +async def test_get_chunk_async_from_dict(): + await test_get_chunk_async(request_type=dict) + + +def test_get_chunk_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.GetChunkRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + call.return_value = retriever.Chunk() + client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_chunk_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.GetChunkRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Chunk()) + await client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_get_chunk_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.get_chunk( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_get_chunk_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_chunk( + retriever_service.GetChunkRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_get_chunk_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Chunk()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.get_chunk( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_get_chunk_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.get_chunk( + retriever_service.GetChunkRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.UpdateChunkRequest, + dict, + ], +) +def test_update_chunk(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + response = client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.UpdateChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +def test_update_chunk_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.UpdateChunkRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.update_chunk(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.UpdateChunkRequest() + + +def test_update_chunk_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_chunk] = mock_rpc + request = {} + client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_chunk_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.update_chunk + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.update_chunk + ] = mock_rpc + + request = {} + await client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.update_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_update_chunk_async( + transport: str = "grpc_asyncio", request_type=retriever_service.UpdateChunkRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + ) + response = await client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.UpdateChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +@pytest.mark.asyncio +async def test_update_chunk_async_from_dict(): + await test_update_chunk_async(request_type=dict) + + +def test_update_chunk_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.UpdateChunkRequest() + + request.chunk.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + call.return_value = retriever.Chunk() + client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "chunk.name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_update_chunk_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.UpdateChunkRequest() + + request.chunk.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Chunk()) + await client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "chunk.name=name_value", + ) in kw["metadata"] + + +def test_update_chunk_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.update_chunk( + chunk=retriever.Chunk(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].chunk + mock_val = retriever.Chunk(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +def test_update_chunk_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_chunk( + retriever_service.UpdateChunkRequest(), + chunk=retriever.Chunk(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.asyncio +async def test_update_chunk_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever.Chunk() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(retriever.Chunk()) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.update_chunk( + chunk=retriever.Chunk(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].chunk + mock_val = retriever.Chunk(name="name_value") + assert arg == mock_val + arg = args[0].update_mask + mock_val = field_mask_pb2.FieldMask(paths=["paths_value"]) + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_update_chunk_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.update_chunk( + retriever_service.UpdateChunkRequest(), + chunk=retriever.Chunk(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.BatchUpdateChunksRequest, + dict, + ], +) +def test_batch_update_chunks(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.BatchUpdateChunksResponse() + response = client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.BatchUpdateChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.BatchUpdateChunksResponse) + + +def test_batch_update_chunks_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.BatchUpdateChunksRequest( + parent="parent_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.batch_update_chunks(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.BatchUpdateChunksRequest( + parent="parent_value", + ) + + +def test_batch_update_chunks_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_update_chunks in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_update_chunks + ] = mock_rpc + request = {} + client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_update_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_update_chunks_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_update_chunks + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_update_chunks + ] = mock_rpc + + request = {} + await client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.batch_update_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_update_chunks_async( + transport: str = "grpc_asyncio", + request_type=retriever_service.BatchUpdateChunksRequest, +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.BatchUpdateChunksResponse() + ) + response = await client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.BatchUpdateChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.BatchUpdateChunksResponse) + + +@pytest.mark.asyncio +async def test_batch_update_chunks_async_from_dict(): + await test_batch_update_chunks_async(request_type=dict) + + +def test_batch_update_chunks_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.BatchUpdateChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + call.return_value = retriever_service.BatchUpdateChunksResponse() + client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_update_chunks_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.BatchUpdateChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.BatchUpdateChunksResponse() + ) + await client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.DeleteChunkRequest, + dict, + ], +) +def test_delete_chunk(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.DeleteChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_delete_chunk_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.DeleteChunkRequest( + name="name_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.delete_chunk(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.DeleteChunkRequest( + name="name_value", + ) + + +def test_delete_chunk_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_chunk] = mock_rpc + request = {} + client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_chunk_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.delete_chunk + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.delete_chunk + ] = mock_rpc + + request = {} + await client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.delete_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_delete_chunk_async( + transport: str = "grpc_asyncio", request_type=retriever_service.DeleteChunkRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.DeleteChunkRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_delete_chunk_async_from_dict(): + await test_delete_chunk_async(request_type=dict) + + +def test_delete_chunk_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.DeleteChunkRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + call.return_value = None + client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_delete_chunk_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.DeleteChunkRequest() + + request.name = "name_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=name_value", + ) in kw["metadata"] + + +def test_delete_chunk_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.delete_chunk( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +def test_delete_chunk_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_chunk( + retriever_service.DeleteChunkRequest(), + name="name_value", + ) + + +@pytest.mark.asyncio +async def test_delete_chunk_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = None + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.delete_chunk( + name="name_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].name + mock_val = "name_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_delete_chunk_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.delete_chunk( + retriever_service.DeleteChunkRequest(), + name="name_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.BatchDeleteChunksRequest, + dict, + ], +) +def test_batch_delete_chunks(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = None + response = client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.BatchDeleteChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +def test_batch_delete_chunks_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.BatchDeleteChunksRequest( + parent="parent_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.batch_delete_chunks(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.BatchDeleteChunksRequest( + parent="parent_value", + ) + + +def test_batch_delete_chunks_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_delete_chunks in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_delete_chunks + ] = mock_rpc + request = {} + client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_delete_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_delete_chunks_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_delete_chunks + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_delete_chunks + ] = mock_rpc + + request = {} + await client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.batch_delete_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_delete_chunks_async( + transport: str = "grpc_asyncio", + request_type=retriever_service.BatchDeleteChunksRequest, +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + response = await client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.BatchDeleteChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.asyncio +async def test_batch_delete_chunks_async_from_dict(): + await test_batch_delete_chunks_async(request_type=dict) + + +def test_batch_delete_chunks_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.BatchDeleteChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + call.return_value = None + client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_delete_chunks_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.BatchDeleteChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.ListChunksRequest, + dict, + ], +) +def test_list_chunks(request_type, transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListChunksResponse( + next_page_token="next_page_token_value", + ) + response = client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = retriever_service.ListChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListChunksPager) + assert response.next_page_token == "next_page_token_value" + + +def test_list_chunks_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = retriever_service.ListChunksRequest( + parent="parent_value", + page_token="page_token_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.list_chunks(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == retriever_service.ListChunksRequest( + parent="parent_value", + page_token="page_token_value", + ) + + +def test_list_chunks_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_chunks in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_chunks] = mock_rpc + request = {} + client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_chunks_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.list_chunks + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.list_chunks + ] = mock_rpc + + request = {} + await client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.list_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_list_chunks_async( + transport: str = "grpc_asyncio", request_type=retriever_service.ListChunksRequest +): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListChunksResponse( + next_page_token="next_page_token_value", + ) + ) + response = await client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = retriever_service.ListChunksRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListChunksAsyncPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.asyncio +async def test_list_chunks_async_from_dict(): + await test_list_chunks_async(request_type=dict) + + +def test_list_chunks_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.ListChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + call.return_value = retriever_service.ListChunksResponse() + client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_chunks_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = retriever_service.ListChunksRequest() + + request.parent = "parent_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListChunksResponse() + ) + await client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "parent=parent_value", + ) in kw["metadata"] + + +def test_list_chunks_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListChunksResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.list_chunks( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +def test_list_chunks_flattened_error(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_chunks( + retriever_service.ListChunksRequest(), + parent="parent_value", + ) + + +@pytest.mark.asyncio +async def test_list_chunks_flattened_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = retriever_service.ListChunksResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListChunksResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.list_chunks( + parent="parent_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].parent + mock_val = "parent_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_list_chunks_flattened_error_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.list_chunks( + retriever_service.ListChunksRequest(), + parent="parent_value", + ) + + +def test_list_chunks_pager(transport_name: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + retriever.Chunk(), + ], + next_page_token="abc", + ), + retriever_service.ListChunksResponse( + chunks=[], + next_page_token="def", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + ], + next_page_token="ghi", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + ], + ), + RuntimeError, + ) + + expected_metadata = () + retry = retries.Retry() + timeout = 5 + expected_metadata = tuple(expected_metadata) + ( + gapic_v1.routing_header.to_grpc_metadata((("parent", ""),)), + ) + pager = client.list_chunks(request={}, retry=retry, timeout=timeout) + + assert pager._metadata == expected_metadata + assert pager._retry == retry + assert pager._timeout == timeout + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, retriever.Chunk) for i in results) + + +def test_list_chunks_pages(transport_name: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport_name, + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + retriever.Chunk(), + ], + next_page_token="abc", + ), + retriever_service.ListChunksResponse( + chunks=[], + next_page_token="def", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + ], + next_page_token="ghi", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + ], + ), + RuntimeError, + ) + pages = list(client.list_chunks(request={}).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +@pytest.mark.asyncio +async def test_list_chunks_async_pager(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_chunks), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + retriever.Chunk(), + ], + next_page_token="abc", + ), + retriever_service.ListChunksResponse( + chunks=[], + next_page_token="def", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + ], + next_page_token="ghi", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + ], + ), + RuntimeError, + ) + async_pager = await client.list_chunks( + request={}, + ) + assert async_pager.next_page_token == "abc" + responses = [] + async for response in async_pager: # pragma: no branch + responses.append(response) + + assert len(responses) == 6 + assert all(isinstance(i, retriever.Chunk) for i in responses) + + +@pytest.mark.asyncio +async def test_list_chunks_async_pages(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.list_chunks), "__call__", new_callable=mock.AsyncMock + ) as call: + # Set the response to a series of pages. + call.side_effect = ( + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + retriever.Chunk(), + ], + next_page_token="abc", + ), + retriever_service.ListChunksResponse( + chunks=[], + next_page_token="def", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + ], + next_page_token="ghi", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + ], + ), + RuntimeError, + ) + pages = [] + # Workaround issue in python 3.9 related to code coverage by adding `# pragma: no branch` + # See https://github.com/googleapis/gapic-generator-python/pull/1174#issuecomment-1025132372 + async for page_ in ( # pragma: no branch + await client.list_chunks(request={}) + ).pages: + pages.append(page_) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_create_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_corpus] = mock_rpc + + request = {} + client.create_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_corpus_rest_required_fields( + request_type=retriever_service.CreateCorpusRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_corpus_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("corpus",))) + + +def test_create_corpus_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus() + + # get arguments that satisfy an http rule for this method + sample_request = {} + + # get truthy value for each flattened field + mock_args = dict( + corpus=retriever.Corpus(name="name_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_corpus(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/corpora" % client.transport._host, args[1] + ) + + +def test_create_corpus_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_corpus( + retriever_service.CreateCorpusRequest(), + corpus=retriever.Corpus(name="name_value"), + ) + + +def test_get_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_corpus] = mock_rpc + + request = {} + client.get_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_corpus_rest_required_fields( + request_type=retriever_service.GetCorpusRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_corpus_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_corpus_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "corpora/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_corpus(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=corpora/*}" % client.transport._host, args[1] + ) + + +def test_get_corpus_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_corpus( + retriever_service.GetCorpusRequest(), + name="name_value", + ) + + +def test_update_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_corpus] = mock_rpc + + request = {} + client.update_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_corpus_rest_required_fields( + request_type=retriever_service.UpdateCorpusRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_corpus._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_corpus_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_corpus._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "corpus", + "updateMask", + ) + ) + ) + + +def test_update_corpus_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus() + + # get arguments that satisfy an http rule for this method + sample_request = {"corpus": {"name": "corpora/sample1"}} + + # get truthy value for each flattened field + mock_args = dict( + corpus=retriever.Corpus(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_corpus(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{corpus.name=corpora/*}" % client.transport._host, args[1] + ) + + +def test_update_corpus_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_corpus( + retriever_service.UpdateCorpusRequest(), + corpus=retriever.Corpus(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_corpus] = mock_rpc + + request = {} + client.delete_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_corpus_rest_required_fields( + request_type=retriever_service.DeleteCorpusRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_corpus._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("force",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_corpus_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_corpus._get_unset_required_fields({}) + assert set(unset_fields) == (set(("force",)) & set(("name",))) + + +def test_delete_corpus_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "corpora/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_corpus(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=corpora/*}" % client.transport._host, args[1] + ) + + +def test_delete_corpus_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_corpus( + retriever_service.DeleteCorpusRequest(), + name="name_value", + ) + + +def test_list_corpora_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_corpora in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_corpora] = mock_rpc + + request = {} + client.list_corpora(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_corpora(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_corpora_rest_pager(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + retriever.Corpus(), + ], + next_page_token="abc", + ), + retriever_service.ListCorporaResponse( + corpora=[], + next_page_token="def", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + ], + next_page_token="ghi", + ), + retriever_service.ListCorporaResponse( + corpora=[ + retriever.Corpus(), + retriever.Corpus(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + retriever_service.ListCorporaResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {} + + pager = client.list_corpora(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, retriever.Corpus) for i in results) + + pages = list(client.list_corpora(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_query_corpus_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.query_corpus in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.query_corpus] = mock_rpc + + request = {} + client.query_corpus(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.query_corpus(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_query_corpus_rest_required_fields( + request_type=retriever_service.QueryCorpusRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request_init["query"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).query_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + jsonified_request["query"] = "query_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).query_corpus._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + assert "query" in jsonified_request + assert jsonified_request["query"] == "query_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever_service.QueryCorpusResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.QueryCorpusResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.query_corpus(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_query_corpus_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.query_corpus._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "name", + "query", + ) + ) + ) + + +def test_create_document_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_document] = mock_rpc + + request = {} + client.create_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_document_rest_required_fields( + request_type=retriever_service.CreateDocumentRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Document() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_document(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_document_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_document._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "document", + ) + ) + ) + + +def test_create_document_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Document() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "corpora/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + document=retriever.Document(name="name_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_document(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{parent=corpora/*}/documents" % client.transport._host, args[1] + ) + + +def test_create_document_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_document( + retriever_service.CreateDocumentRequest(), + parent="parent_value", + document=retriever.Document(name="name_value"), + ) + + +def test_get_document_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_document] = mock_rpc + + request = {} + client.get_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_document_rest_required_fields( + request_type=retriever_service.GetDocumentRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Document() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_document(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_document_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_document._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_document_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Document() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "corpora/sample1/documents/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_document(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=corpora/*/documents/*}" % client.transport._host, args[1] + ) + + +def test_get_document_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_document( + retriever_service.GetDocumentRequest(), + name="name_value", + ) + + +def test_update_document_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_document] = mock_rpc + + request = {} + client.update_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_document_rest_required_fields( + request_type=retriever_service.UpdateDocumentRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_document._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Document() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_document(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_document_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_document._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "document", + "updateMask", + ) + ) + ) + + +def test_update_document_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Document() + + # get arguments that satisfy an http rule for this method + sample_request = {"document": {"name": "corpora/sample1/documents/sample2"}} + + # get truthy value for each flattened field + mock_args = dict( + document=retriever.Document(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_document(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{document.name=corpora/*/documents/*}" % client.transport._host, + args[1], + ) + + +def test_update_document_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_document( + retriever_service.UpdateDocumentRequest(), + document=retriever.Document(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_delete_document_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_document] = mock_rpc + + request = {} + client.delete_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_document_rest_required_fields( + request_type=retriever_service.DeleteDocumentRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_document._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("force",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_document(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_document_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_document._get_unset_required_fields({}) + assert set(unset_fields) == (set(("force",)) & set(("name",))) + + +def test_delete_document_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "corpora/sample1/documents/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_document(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=corpora/*/documents/*}" % client.transport._host, args[1] + ) + + +def test_delete_document_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_document( + retriever_service.DeleteDocumentRequest(), + name="name_value", + ) + + +def test_list_documents_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_documents in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_documents] = mock_rpc + + request = {} + client.list_documents(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_documents(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_documents_rest_required_fields( + request_type=retriever_service.ListDocumentsRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_documents._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_documents._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListDocumentsResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.ListDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_documents(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_documents_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_documents._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +def test_list_documents_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListDocumentsResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "corpora/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever_service.ListDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.list_documents(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{parent=corpora/*}/documents" % client.transport._host, args[1] + ) + + +def test_list_documents_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_documents( + retriever_service.ListDocumentsRequest(), + parent="parent_value", + ) + + +def test_list_documents_rest_pager(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + retriever.Document(), + ], + next_page_token="abc", + ), + retriever_service.ListDocumentsResponse( + documents=[], + next_page_token="def", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + ], + next_page_token="ghi", + ), + retriever_service.ListDocumentsResponse( + documents=[ + retriever.Document(), + retriever.Document(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + retriever_service.ListDocumentsResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "corpora/sample1"} + + pager = client.list_documents(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, retriever.Document) for i in results) + + pages = list(client.list_documents(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_query_document_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.query_document in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.query_document] = mock_rpc + + request = {} + client.query_document(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.query_document(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_query_document_rest_required_fields( + request_type=retriever_service.QueryDocumentRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request_init["query"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).query_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + jsonified_request["query"] = "query_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).query_document._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + assert "query" in jsonified_request + assert jsonified_request["query"] == "query_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever_service.QueryDocumentResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.QueryDocumentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.query_document(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_query_document_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.query_document._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "name", + "query", + ) + ) + ) + + +def test_create_chunk_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.create_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.create_chunk] = mock_rpc + + request = {} + client.create_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.create_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_create_chunk_rest_required_fields( + request_type=retriever_service.CreateChunkRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).create_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.create_chunk(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_create_chunk_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.create_chunk._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "parent", + "chunk", + ) + ) + ) + + +def test_create_chunk_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "corpora/sample1/documents/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + chunk=retriever.Chunk(name="name_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.create_chunk(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{parent=corpora/*/documents/*}/chunks" % client.transport._host, + args[1], + ) + + +def test_create_chunk_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.create_chunk( + retriever_service.CreateChunkRequest(), + parent="parent_value", + chunk=retriever.Chunk(name="name_value"), + ) + + +def test_batch_create_chunks_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_create_chunks in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_create_chunks + ] = mock_rpc + + request = {} + client.batch_create_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_create_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_batch_create_chunks_rest_required_fields( + request_type=retriever_service.BatchCreateChunksRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_create_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_create_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever_service.BatchCreateChunksResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.BatchCreateChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.batch_create_chunks(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_batch_create_chunks_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.batch_create_chunks._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("requests",))) + + +def test_get_chunk_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.get_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.get_chunk] = mock_rpc + + request = {} + client.get_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.get_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_get_chunk_rest_required_fields(request_type=retriever_service.GetChunkRequest): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).get_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_chunk(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_get_chunk_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.get_chunk._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_get_chunk_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk() + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.get_chunk(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=corpora/*/documents/*/chunks/*}" % client.transport._host, + args[1], + ) + + +def test_get_chunk_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.get_chunk( + retriever_service.GetChunkRequest(), + name="name_value", + ) + + +def test_update_chunk_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.update_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.update_chunk] = mock_rpc + + request = {} + client.update_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.update_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_update_chunk_rest_required_fields( + request_type=retriever_service.UpdateChunkRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).update_chunk._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set(("update_mask",)) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "patch", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.update_chunk(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_update_chunk_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.update_chunk._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(("updateMask",)) + & set( + ( + "chunk", + "updateMask", + ) + ) + ) + + +def test_update_chunk_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk() + + # get arguments that satisfy an http rule for this method + sample_request = { + "chunk": {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + } + + # get truthy value for each flattened field + mock_args = dict( + chunk=retriever.Chunk(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.update_chunk(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{chunk.name=corpora/*/documents/*/chunks/*}" + % client.transport._host, + args[1], + ) + + +def test_update_chunk_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.update_chunk( + retriever_service.UpdateChunkRequest(), + chunk=retriever.Chunk(name="name_value"), + update_mask=field_mask_pb2.FieldMask(paths=["paths_value"]), + ) + + +def test_batch_update_chunks_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_update_chunks in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_update_chunks + ] = mock_rpc + + request = {} + client.batch_update_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_update_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_batch_update_chunks_rest_required_fields( + request_type=retriever_service.BatchUpdateChunksRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_update_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_update_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever_service.BatchUpdateChunksResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.BatchUpdateChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.batch_update_chunks(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_batch_update_chunks_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.batch_update_chunks._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("requests",))) + + +def test_delete_chunk_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.delete_chunk in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.delete_chunk] = mock_rpc + + request = {} + client.delete_chunk(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.delete_chunk(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_delete_chunk_rest_required_fields( + request_type=retriever_service.DeleteChunkRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["name"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["name"] = "name_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).delete_chunk._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "name" in jsonified_request + assert jsonified_request["name"] == "name_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "delete", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.delete_chunk(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_delete_chunk_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.delete_chunk._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("name",))) + + +def test_delete_chunk_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # get arguments that satisfy an http rule for this method + sample_request = {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + + # get truthy value for each flattened field + mock_args = dict( + name="name_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.delete_chunk(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{name=corpora/*/documents/*/chunks/*}" % client.transport._host, + args[1], + ) + + +def test_delete_chunk_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.delete_chunk( + retriever_service.DeleteChunkRequest(), + name="name_value", + ) + + +def test_batch_delete_chunks_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._transport.batch_delete_chunks in client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_delete_chunks + ] = mock_rpc + + request = {} + client.batch_delete_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_delete_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_batch_delete_chunks_rest_required_fields( + request_type=retriever_service.BatchDeleteChunksRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_delete_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_delete_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = None + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + json_return_value = "" + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.batch_delete_chunks(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_batch_delete_chunks_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.batch_delete_chunks._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("requests",))) + + +def test_list_chunks_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.list_chunks in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.list_chunks] = mock_rpc + + request = {} + client.list_chunks(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.list_chunks(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_list_chunks_rest_required_fields( + request_type=retriever_service.ListChunksRequest, +): + transport_class = transports.RetrieverServiceRestTransport + + request_init = {} + request_init["parent"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_chunks._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["parent"] = "parent_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).list_chunks._get_unset_required_fields(jsonified_request) + # Check that path parameters and body parameters are not mixing in. + assert not set(unset_fields) - set( + ( + "page_size", + "page_token", + ) + ) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "parent" in jsonified_request + assert jsonified_request["parent"] == "parent_value" + + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListChunksResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "get", + "query_params": pb_request, + } + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.ListChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_chunks(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_list_chunks_rest_unset_required_fields(): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.list_chunks._get_unset_required_fields({}) + assert set(unset_fields) == ( + set( + ( + "pageSize", + "pageToken", + ) + ) + & set(("parent",)) + ) + + +def test_list_chunks_rest_flattened(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListChunksResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"parent": "corpora/sample1/documents/sample2"} + + # get truthy value for each flattened field + mock_args = dict( + parent="parent_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = retriever_service.ListChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.list_chunks(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{parent=corpora/*/documents/*}/chunks" % client.transport._host, + args[1], + ) + + +def test_list_chunks_rest_flattened_error(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.list_chunks( + retriever_service.ListChunksRequest(), + parent="parent_value", + ) + + +def test_list_chunks_rest_pager(transport: str = "rest"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # TODO(kbandes): remove this mock unless there's a good reason for it. + # with mock.patch.object(path_template, 'transcode') as transcode: + # Set the response as a series of pages + response = ( + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + retriever.Chunk(), + ], + next_page_token="abc", + ), + retriever_service.ListChunksResponse( + chunks=[], + next_page_token="def", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + ], + next_page_token="ghi", + ), + retriever_service.ListChunksResponse( + chunks=[ + retriever.Chunk(), + retriever.Chunk(), + ], + ), + ) + # Two responses for two calls + response = response + response + + # Wrap the values into proper Response objs + response = tuple( + retriever_service.ListChunksResponse.to_json(x) for x in response + ) + return_values = tuple(Response() for i in response) + for return_val, response_val in zip(return_values, response): + return_val._content = response_val.encode("UTF-8") + return_val.status_code = 200 + req.side_effect = return_values + + sample_request = {"parent": "corpora/sample1/documents/sample2"} + + pager = client.list_chunks(request=sample_request) + + results = list(pager) + assert len(results) == 6 + assert all(isinstance(i, retriever.Chunk) for i in results) + + pages = list(client.list_chunks(request=sample_request).pages) + for page_, token in zip(pages, ["abc", "def", "ghi", ""]): + assert page_.raw_page.next_page_token == token + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.RetrieverServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.RetrieverServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = RetrieverServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.RetrieverServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = RetrieverServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = RetrieverServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.RetrieverServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = RetrieverServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.RetrieverServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = RetrieverServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.RetrieverServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.RetrieverServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceGrpcAsyncIOTransport, + transports.RetrieverServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = RetrieverServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_corpus_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + call.return_value = retriever.Corpus() + client.create_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_corpus_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + call.return_value = retriever.Corpus() + client.get_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_corpus_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + call.return_value = retriever.Corpus() + client.update_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_corpus_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + call.return_value = None + client.delete_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_corpora_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + call.return_value = retriever_service.ListCorporaResponse() + client.list_corpora(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListCorporaRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_query_corpus_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + call.return_value = retriever_service.QueryCorpusResponse() + client.query_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.QueryCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_document_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + call.return_value = retriever.Document() + client.create_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_document_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + call.return_value = retriever.Document() + client.get_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_document_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + call.return_value = retriever.Document() + client.update_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_document_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + call.return_value = None + client.delete_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_documents_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + call.return_value = retriever_service.ListDocumentsResponse() + client.list_documents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListDocumentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_query_document_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + call.return_value = retriever_service.QueryDocumentResponse() + client.query_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.QueryDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_chunk_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + call.return_value = retriever.Chunk() + client.create_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_create_chunks_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + call.return_value = retriever_service.BatchCreateChunksResponse() + client.batch_create_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchCreateChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_chunk_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + call.return_value = retriever.Chunk() + client.get_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_chunk_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + call.return_value = retriever.Chunk() + client.update_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_update_chunks_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + call.return_value = retriever_service.BatchUpdateChunksResponse() + client.batch_update_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchUpdateChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_chunk_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + call.return_value = None + client.delete_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_delete_chunks_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + call.return_value = None + client.batch_delete_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchDeleteChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_chunks_empty_call_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + call.return_value = retriever_service.ListChunksResponse() + client.list_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListChunksRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = RetrieverServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_corpus_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + ) + await client.create_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_corpus_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + ) + await client.get_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_corpus_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + ) + await client.update_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_corpus_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_corpora_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListCorporaResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_corpora(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListCorporaRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_query_corpus_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.QueryCorpusResponse() + ) + await client.query_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.QueryCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_document_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Document( + name="name_value", + display_name="display_name_value", + ) + ) + await client.create_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_document_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Document( + name="name_value", + display_name="display_name_value", + ) + ) + await client.get_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_document_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Document( + name="name_value", + display_name="display_name_value", + ) + ) + await client.update_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_document_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_documents_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListDocumentsResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_documents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListDocumentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_query_document_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.QueryDocumentResponse() + ) + await client.query_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.QueryDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_create_chunk_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + ) + await client.create_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_batch_create_chunks_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.BatchCreateChunksResponse() + ) + await client.batch_create_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchCreateChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_get_chunk_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + ) + await client.get_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_update_chunk_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + ) + await client.update_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_batch_update_chunks_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.BatchUpdateChunksResponse() + ) + await client.batch_update_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchUpdateChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_delete_chunk_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.delete_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_batch_delete_chunks_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall(None) + await client.batch_delete_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchDeleteChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_list_chunks_empty_call_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + retriever_service.ListChunksResponse( + next_page_token="next_page_token_value", + ) + ) + await client.list_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListChunksRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = RetrieverServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_create_corpus_rest_bad_request( + request_type=retriever_service.CreateCorpusRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_corpus(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.CreateCorpusRequest, + dict, + ], +) +def test_create_corpus_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request_init["corpus"] = { + "name": "name_value", + "display_name": "display_name_value", + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = retriever_service.CreateCorpusRequest.meta.fields["corpus"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["corpus"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["corpus"][field])): + del request_init["corpus"][field][i][subfield] + else: + del request_init["corpus"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_corpus(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_corpus_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_create_corpus" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_create_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.CreateCorpusRequest.pb( + retriever_service.CreateCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Corpus.to_json(retriever.Corpus()) + req.return_value.content = return_value + + request = retriever_service.CreateCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Corpus() + + client.create_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_corpus_rest_bad_request(request_type=retriever_service.GetCorpusRequest): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_corpus(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.GetCorpusRequest, + dict, + ], +) +def test_get_corpus_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_corpus(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_corpus_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_get_corpus" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_get_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.GetCorpusRequest.pb( + retriever_service.GetCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Corpus.to_json(retriever.Corpus()) + req.return_value.content = return_value + + request = retriever_service.GetCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Corpus() + + client.get_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_corpus_rest_bad_request( + request_type=retriever_service.UpdateCorpusRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"corpus": {"name": "corpora/sample1"}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.update_corpus(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.UpdateCorpusRequest, + dict, + ], +) +def test_update_corpus_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"corpus": {"name": "corpora/sample1"}} + request_init["corpus"] = { + "name": "corpora/sample1", + "display_name": "display_name_value", + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = retriever_service.UpdateCorpusRequest.meta.fields["corpus"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["corpus"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["corpus"][field])): + del request_init["corpus"][field][i][subfield] + else: + del request_init["corpus"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Corpus( + name="name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Corpus.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.update_corpus(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Corpus) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_corpus_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_update_corpus" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_update_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.UpdateCorpusRequest.pb( + retriever_service.UpdateCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Corpus.to_json(retriever.Corpus()) + req.return_value.content = return_value + + request = retriever_service.UpdateCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Corpus() + + client.update_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_corpus_rest_bad_request( + request_type=retriever_service.DeleteCorpusRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_corpus(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.DeleteCorpusRequest, + dict, + ], +) +def test_delete_corpus_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_corpus(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_corpus_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_delete_corpus" + ) as pre: + pre.assert_not_called() + pb_message = retriever_service.DeleteCorpusRequest.pb( + retriever_service.DeleteCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = retriever_service.DeleteCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_list_corpora_rest_bad_request( + request_type=retriever_service.ListCorporaRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_corpora(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.ListCorporaRequest, + dict, + ], +) +def test_list_corpora_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListCorporaResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.ListCorporaResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_corpora(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListCorporaPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_corpora_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_list_corpora" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_list_corpora" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.ListCorporaRequest.pb( + retriever_service.ListCorporaRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.ListCorporaResponse.to_json( + retriever_service.ListCorporaResponse() + ) + req.return_value.content = return_value + + request = retriever_service.ListCorporaRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.ListCorporaResponse() + + client.list_corpora( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_query_corpus_rest_bad_request( + request_type=retriever_service.QueryCorpusRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.query_corpus(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.QueryCorpusRequest, + dict, + ], +) +def test_query_corpus_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.QueryCorpusResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.QueryCorpusResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.query_corpus(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.QueryCorpusResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_query_corpus_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_query_corpus" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_query_corpus" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.QueryCorpusRequest.pb( + retriever_service.QueryCorpusRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.QueryCorpusResponse.to_json( + retriever_service.QueryCorpusResponse() + ) + req.return_value.content = return_value + + request = retriever_service.QueryCorpusRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.QueryCorpusResponse() + + client.query_corpus( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_document_rest_bad_request( + request_type=retriever_service.CreateDocumentRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_document(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.CreateDocumentRequest, + dict, + ], +) +def test_create_document_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1"} + request_init["document"] = { + "name": "name_value", + "display_name": "display_name_value", + "custom_metadata": [ + { + "string_value": "string_value_value", + "string_list_value": {"values": ["values_value1", "values_value2"]}, + "numeric_value": 0.1391, + "key": "key_value", + } + ], + "update_time": {"seconds": 751, "nanos": 543}, + "create_time": {}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = retriever_service.CreateDocumentRequest.meta.fields["document"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["document"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["document"][field])): + del request_init["document"][field][i][subfield] + else: + del request_init["document"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Document( + name="name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_document(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_document_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_create_document" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_create_document" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.CreateDocumentRequest.pb( + retriever_service.CreateDocumentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Document.to_json(retriever.Document()) + req.return_value.content = return_value + + request = retriever_service.CreateDocumentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Document() + + client.create_document( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_document_rest_bad_request( + request_type=retriever_service.GetDocumentRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_document(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.GetDocumentRequest, + dict, + ], +) +def test_get_document_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Document( + name="name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_document(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_document_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_get_document" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_get_document" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.GetDocumentRequest.pb( + retriever_service.GetDocumentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Document.to_json(retriever.Document()) + req.return_value.content = return_value + + request = retriever_service.GetDocumentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Document() + + client.get_document( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_document_rest_bad_request( + request_type=retriever_service.UpdateDocumentRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"document": {"name": "corpora/sample1/documents/sample2"}} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.update_document(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.UpdateDocumentRequest, + dict, + ], +) +def test_update_document_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"document": {"name": "corpora/sample1/documents/sample2"}} + request_init["document"] = { + "name": "corpora/sample1/documents/sample2", + "display_name": "display_name_value", + "custom_metadata": [ + { + "string_value": "string_value_value", + "string_list_value": {"values": ["values_value1", "values_value2"]}, + "numeric_value": 0.1391, + "key": "key_value", + } + ], + "update_time": {"seconds": 751, "nanos": 543}, + "create_time": {}, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = retriever_service.UpdateDocumentRequest.meta.fields["document"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["document"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["document"][field])): + del request_init["document"][field][i][subfield] + else: + del request_init["document"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Document( + name="name_value", + display_name="display_name_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Document.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.update_document(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Document) + assert response.name == "name_value" + assert response.display_name == "display_name_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_document_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_update_document" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_update_document" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.UpdateDocumentRequest.pb( + retriever_service.UpdateDocumentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Document.to_json(retriever.Document()) + req.return_value.content = return_value + + request = retriever_service.UpdateDocumentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Document() + + client.update_document( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_document_rest_bad_request( + request_type=retriever_service.DeleteDocumentRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_document(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.DeleteDocumentRequest, + dict, + ], +) +def test_delete_document_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_document(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_document_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_delete_document" + ) as pre: + pre.assert_not_called() + pb_message = retriever_service.DeleteDocumentRequest.pb( + retriever_service.DeleteDocumentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = retriever_service.DeleteDocumentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_document( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_list_documents_rest_bad_request( + request_type=retriever_service.ListDocumentsRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_documents(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.ListDocumentsRequest, + dict, + ], +) +def test_list_documents_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListDocumentsResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.ListDocumentsResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_documents(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListDocumentsPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_documents_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_list_documents" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_list_documents" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.ListDocumentsRequest.pb( + retriever_service.ListDocumentsRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.ListDocumentsResponse.to_json( + retriever_service.ListDocumentsResponse() + ) + req.return_value.content = return_value + + request = retriever_service.ListDocumentsRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.ListDocumentsResponse() + + client.list_documents( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_query_document_rest_bad_request( + request_type=retriever_service.QueryDocumentRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.query_document(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.QueryDocumentRequest, + dict, + ], +) +def test_query_document_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.QueryDocumentResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.QueryDocumentResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.query_document(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.QueryDocumentResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_query_document_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_query_document" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_query_document" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.QueryDocumentRequest.pb( + retriever_service.QueryDocumentRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.QueryDocumentResponse.to_json( + retriever_service.QueryDocumentResponse() + ) + req.return_value.content = return_value + + request = retriever_service.QueryDocumentRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.QueryDocumentResponse() + + client.query_document( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_create_chunk_rest_bad_request( + request_type=retriever_service.CreateChunkRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.create_chunk(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.CreateChunkRequest, + dict, + ], +) +def test_create_chunk_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request_init["chunk"] = { + "name": "name_value", + "data": {"string_value": "string_value_value"}, + "custom_metadata": [ + { + "string_value": "string_value_value", + "string_list_value": {"values": ["values_value1", "values_value2"]}, + "numeric_value": 0.1391, + "key": "key_value", + } + ], + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "state": 1, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = retriever_service.CreateChunkRequest.meta.fields["chunk"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["chunk"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["chunk"][field])): + del request_init["chunk"][field][i][subfield] + else: + del request_init["chunk"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.create_chunk(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_create_chunk_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_create_chunk" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_create_chunk" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.CreateChunkRequest.pb( + retriever_service.CreateChunkRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Chunk.to_json(retriever.Chunk()) + req.return_value.content = return_value + + request = retriever_service.CreateChunkRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Chunk() + + client.create_chunk( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_batch_create_chunks_rest_bad_request( + request_type=retriever_service.BatchCreateChunksRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.batch_create_chunks(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.BatchCreateChunksRequest, + dict, + ], +) +def test_batch_create_chunks_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.BatchCreateChunksResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.BatchCreateChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.batch_create_chunks(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.BatchCreateChunksResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_batch_create_chunks_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_batch_create_chunks" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_batch_create_chunks" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.BatchCreateChunksRequest.pb( + retriever_service.BatchCreateChunksRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.BatchCreateChunksResponse.to_json( + retriever_service.BatchCreateChunksResponse() + ) + req.return_value.content = return_value + + request = retriever_service.BatchCreateChunksRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.BatchCreateChunksResponse() + + client.batch_create_chunks( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_chunk_rest_bad_request(request_type=retriever_service.GetChunkRequest): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_chunk(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.GetChunkRequest, + dict, + ], +) +def test_get_chunk_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.get_chunk(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_get_chunk_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_get_chunk" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_get_chunk" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.GetChunkRequest.pb( + retriever_service.GetChunkRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Chunk.to_json(retriever.Chunk()) + req.return_value.content = return_value + + request = retriever_service.GetChunkRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Chunk() + + client.get_chunk( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_update_chunk_rest_bad_request( + request_type=retriever_service.UpdateChunkRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = { + "chunk": {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + } + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.update_chunk(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.UpdateChunkRequest, + dict, + ], +) +def test_update_chunk_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = { + "chunk": {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + } + request_init["chunk"] = { + "name": "corpora/sample1/documents/sample2/chunks/sample3", + "data": {"string_value": "string_value_value"}, + "custom_metadata": [ + { + "string_value": "string_value_value", + "string_list_value": {"values": ["values_value1", "values_value2"]}, + "numeric_value": 0.1391, + "key": "key_value", + } + ], + "create_time": {"seconds": 751, "nanos": 543}, + "update_time": {}, + "state": 1, + } + # The version of a generated dependency at test runtime may differ from the version used during generation. + # Delete any fields which are not present in the current runtime dependency + # See https://github.com/googleapis/gapic-generator-python/issues/1748 + + # Determine if the message type is proto-plus or protobuf + test_field = retriever_service.UpdateChunkRequest.meta.fields["chunk"] + + def get_message_fields(field): + # Given a field which is a message (composite type), return a list with + # all the fields of the message. + # If the field is not a composite type, return an empty list. + message_fields = [] + + if hasattr(field, "message") and field.message: + is_field_type_proto_plus_type = not hasattr(field.message, "DESCRIPTOR") + + if is_field_type_proto_plus_type: + message_fields = field.message.meta.fields.values() + # Add `# pragma: NO COVER` because there may not be any `*_pb2` field types + else: # pragma: NO COVER + message_fields = field.message.DESCRIPTOR.fields + return message_fields + + runtime_nested_fields = [ + (field.name, nested_field.name) + for field in get_message_fields(test_field) + for nested_field in get_message_fields(field) + ] + + subfields_not_in_runtime = [] + + # For each item in the sample request, create a list of sub fields which are not present at runtime + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for field, value in request_init["chunk"].items(): # pragma: NO COVER + result = None + is_repeated = False + # For repeated fields + if isinstance(value, list) and len(value): + is_repeated = True + result = value[0] + # For fields where the type is another message + if isinstance(value, dict): + result = value + + if result and hasattr(result, "keys"): + for subfield in result.keys(): + if (field, subfield) not in runtime_nested_fields: + subfields_not_in_runtime.append( + { + "field": field, + "subfield": subfield, + "is_repeated": is_repeated, + } + ) + + # Remove fields from the sample request which are not present in the runtime version of the dependency + # Add `# pragma: NO COVER` because this test code will not run if all subfields are present at runtime + for subfield_to_delete in subfields_not_in_runtime: # pragma: NO COVER + field = subfield_to_delete.get("field") + field_repeated = subfield_to_delete.get("is_repeated") + subfield = subfield_to_delete.get("subfield") + if subfield: + if field_repeated: + for i in range(0, len(request_init["chunk"][field])): + del request_init["chunk"][field][i][subfield] + else: + del request_init["chunk"][field][subfield] + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever.Chunk( + name="name_value", + state=retriever.Chunk.State.STATE_PENDING_PROCESSING, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever.Chunk.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.update_chunk(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever.Chunk) + assert response.name == "name_value" + assert response.state == retriever.Chunk.State.STATE_PENDING_PROCESSING + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_update_chunk_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_update_chunk" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_update_chunk" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.UpdateChunkRequest.pb( + retriever_service.UpdateChunkRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever.Chunk.to_json(retriever.Chunk()) + req.return_value.content = return_value + + request = retriever_service.UpdateChunkRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever.Chunk() + + client.update_chunk( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_batch_update_chunks_rest_bad_request( + request_type=retriever_service.BatchUpdateChunksRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.batch_update_chunks(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.BatchUpdateChunksRequest, + dict, + ], +) +def test_batch_update_chunks_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.BatchUpdateChunksResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.BatchUpdateChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.batch_update_chunks(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, retriever_service.BatchUpdateChunksResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_batch_update_chunks_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_batch_update_chunks" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_batch_update_chunks" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.BatchUpdateChunksRequest.pb( + retriever_service.BatchUpdateChunksRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.BatchUpdateChunksResponse.to_json( + retriever_service.BatchUpdateChunksResponse() + ) + req.return_value.content = return_value + + request = retriever_service.BatchUpdateChunksRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.BatchUpdateChunksResponse() + + client.batch_update_chunks( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_delete_chunk_rest_bad_request( + request_type=retriever_service.DeleteChunkRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.delete_chunk(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.DeleteChunkRequest, + dict, + ], +) +def test_delete_chunk_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"name": "corpora/sample1/documents/sample2/chunks/sample3"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.delete_chunk(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_delete_chunk_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_delete_chunk" + ) as pre: + pre.assert_not_called() + pb_message = retriever_service.DeleteChunkRequest.pb( + retriever_service.DeleteChunkRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = retriever_service.DeleteChunkRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.delete_chunk( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_batch_delete_chunks_rest_bad_request( + request_type=retriever_service.BatchDeleteChunksRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.batch_delete_chunks(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.BatchDeleteChunksRequest, + dict, + ], +) +def test_batch_delete_chunks_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = None + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = "" + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.batch_delete_chunks(request) + + # Establish that the response is the type that we expect. + assert response is None + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_batch_delete_chunks_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_batch_delete_chunks" + ) as pre: + pre.assert_not_called() + pb_message = retriever_service.BatchDeleteChunksRequest.pb( + retriever_service.BatchDeleteChunksRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + request = retriever_service.BatchDeleteChunksRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + + client.batch_delete_chunks( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + + +def test_list_chunks_rest_bad_request(request_type=retriever_service.ListChunksRequest): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_chunks(request) + + +@pytest.mark.parametrize( + "request_type", + [ + retriever_service.ListChunksRequest, + dict, + ], +) +def test_list_chunks_rest_call_success(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"parent": "corpora/sample1/documents/sample2"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = retriever_service.ListChunksResponse( + next_page_token="next_page_token_value", + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = retriever_service.ListChunksResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.list_chunks(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, pagers.ListChunksPager) + assert response.next_page_token == "next_page_token_value" + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_list_chunks_rest_interceptors(null_interceptor): + transport = transports.RetrieverServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.RetrieverServiceRestInterceptor(), + ) + client = RetrieverServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "post_list_chunks" + ) as post, mock.patch.object( + transports.RetrieverServiceRestInterceptor, "pre_list_chunks" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = retriever_service.ListChunksRequest.pb( + retriever_service.ListChunksRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = retriever_service.ListChunksResponse.to_json( + retriever_service.ListChunksResponse() + ) + req.return_value.content = return_value + + request = retriever_service.ListChunksRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = retriever_service.ListChunksResponse() + + client.list_chunks( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_corpus_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_corpus), "__call__") as call: + client.create_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_corpus_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_corpus), "__call__") as call: + client.get_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_corpus_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_corpus), "__call__") as call: + client.update_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_corpus_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_corpus), "__call__") as call: + client.delete_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_corpora_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_corpora), "__call__") as call: + client.list_corpora(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListCorporaRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_query_corpus_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.query_corpus), "__call__") as call: + client.query_corpus(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.QueryCorpusRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_document_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_document), "__call__") as call: + client.create_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_document_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_document), "__call__") as call: + client.get_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_document_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_document), "__call__") as call: + client.update_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_document_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_document), "__call__") as call: + client.delete_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_documents_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_documents), "__call__") as call: + client.list_documents(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListDocumentsRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_query_document_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.query_document), "__call__") as call: + client.query_document(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.QueryDocumentRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_create_chunk_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.create_chunk), "__call__") as call: + client.create_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.CreateChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_create_chunks_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_create_chunks), "__call__" + ) as call: + client.batch_create_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchCreateChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_get_chunk_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.get_chunk), "__call__") as call: + client.get_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.GetChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_update_chunk_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.update_chunk), "__call__") as call: + client.update_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.UpdateChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_update_chunks_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_update_chunks), "__call__" + ) as call: + client.batch_update_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchUpdateChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_delete_chunk_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.delete_chunk), "__call__") as call: + client.delete_chunk(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.DeleteChunkRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_delete_chunks_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.batch_delete_chunks), "__call__" + ) as call: + client.batch_delete_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.BatchDeleteChunksRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_list_chunks_empty_call_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.list_chunks), "__call__") as call: + client.list_chunks(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = retriever_service.ListChunksRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.RetrieverServiceGrpcTransport, + ) + + +def test_retriever_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.RetrieverServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_retriever_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.retriever_service.transports.RetrieverServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.RetrieverServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "create_corpus", + "get_corpus", + "update_corpus", + "delete_corpus", + "list_corpora", + "query_corpus", + "create_document", + "get_document", + "update_document", + "delete_document", + "list_documents", + "query_document", + "create_chunk", + "batch_create_chunks", + "get_chunk", + "update_chunk", + "batch_update_chunks", + "delete_chunk", + "batch_delete_chunks", + "list_chunks", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_retriever_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.retriever_service.transports.RetrieverServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.RetrieverServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_retriever_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.retriever_service.transports.RetrieverServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.RetrieverServiceTransport() + adc.assert_called_once() + + +def test_retriever_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + RetrieverServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceGrpcAsyncIOTransport, + ], +) +def test_retriever_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceGrpcAsyncIOTransport, + transports.RetrieverServiceRestTransport, + ], +) +def test_retriever_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.RetrieverServiceGrpcTransport, grpc_helpers), + (transports.RetrieverServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_retriever_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceGrpcAsyncIOTransport, + ], +) +def test_retriever_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_retriever_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.RetrieverServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_retriever_service_host_no_port(transport_name): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_retriever_service_host_with_port(transport_name): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_retriever_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = RetrieverServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = RetrieverServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.create_corpus._session + session2 = client2.transport.create_corpus._session + assert session1 != session2 + session1 = client1.transport.get_corpus._session + session2 = client2.transport.get_corpus._session + assert session1 != session2 + session1 = client1.transport.update_corpus._session + session2 = client2.transport.update_corpus._session + assert session1 != session2 + session1 = client1.transport.delete_corpus._session + session2 = client2.transport.delete_corpus._session + assert session1 != session2 + session1 = client1.transport.list_corpora._session + session2 = client2.transport.list_corpora._session + assert session1 != session2 + session1 = client1.transport.query_corpus._session + session2 = client2.transport.query_corpus._session + assert session1 != session2 + session1 = client1.transport.create_document._session + session2 = client2.transport.create_document._session + assert session1 != session2 + session1 = client1.transport.get_document._session + session2 = client2.transport.get_document._session + assert session1 != session2 + session1 = client1.transport.update_document._session + session2 = client2.transport.update_document._session + assert session1 != session2 + session1 = client1.transport.delete_document._session + session2 = client2.transport.delete_document._session + assert session1 != session2 + session1 = client1.transport.list_documents._session + session2 = client2.transport.list_documents._session + assert session1 != session2 + session1 = client1.transport.query_document._session + session2 = client2.transport.query_document._session + assert session1 != session2 + session1 = client1.transport.create_chunk._session + session2 = client2.transport.create_chunk._session + assert session1 != session2 + session1 = client1.transport.batch_create_chunks._session + session2 = client2.transport.batch_create_chunks._session + assert session1 != session2 + session1 = client1.transport.get_chunk._session + session2 = client2.transport.get_chunk._session + assert session1 != session2 + session1 = client1.transport.update_chunk._session + session2 = client2.transport.update_chunk._session + assert session1 != session2 + session1 = client1.transport.batch_update_chunks._session + session2 = client2.transport.batch_update_chunks._session + assert session1 != session2 + session1 = client1.transport.delete_chunk._session + session2 = client2.transport.delete_chunk._session + assert session1 != session2 + session1 = client1.transport.batch_delete_chunks._session + session2 = client2.transport.batch_delete_chunks._session + assert session1 != session2 + session1 = client1.transport.list_chunks._session + session2 = client2.transport.list_chunks._session + assert session1 != session2 + + +def test_retriever_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.RetrieverServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_retriever_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.RetrieverServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceGrpcAsyncIOTransport, + ], +) +def test_retriever_service_transport_channel_mtls_with_client_cert_source( + transport_class, +): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [ + transports.RetrieverServiceGrpcTransport, + transports.RetrieverServiceGrpcAsyncIOTransport, + ], +) +def test_retriever_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_chunk_path(): + corpus = "squid" + document = "clam" + chunk = "whelk" + expected = "corpora/{corpus}/documents/{document}/chunks/{chunk}".format( + corpus=corpus, + document=document, + chunk=chunk, + ) + actual = RetrieverServiceClient.chunk_path(corpus, document, chunk) + assert expected == actual + + +def test_parse_chunk_path(): + expected = { + "corpus": "octopus", + "document": "oyster", + "chunk": "nudibranch", + } + path = RetrieverServiceClient.chunk_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_chunk_path(path) + assert expected == actual + + +def test_corpus_path(): + corpus = "cuttlefish" + expected = "corpora/{corpus}".format( + corpus=corpus, + ) + actual = RetrieverServiceClient.corpus_path(corpus) + assert expected == actual + + +def test_parse_corpus_path(): + expected = { + "corpus": "mussel", + } + path = RetrieverServiceClient.corpus_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_corpus_path(path) + assert expected == actual + + +def test_document_path(): + corpus = "winkle" + document = "nautilus" + expected = "corpora/{corpus}/documents/{document}".format( + corpus=corpus, + document=document, + ) + actual = RetrieverServiceClient.document_path(corpus, document) + assert expected == actual + + +def test_parse_document_path(): + expected = { + "corpus": "scallop", + "document": "abalone", + } + path = RetrieverServiceClient.document_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_document_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "squid" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = RetrieverServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "clam", + } + path = RetrieverServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "whelk" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = RetrieverServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "octopus", + } + path = RetrieverServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "oyster" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = RetrieverServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "nudibranch", + } + path = RetrieverServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "cuttlefish" + expected = "projects/{project}".format( + project=project, + ) + actual = RetrieverServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "mussel", + } + path = RetrieverServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "winkle" + location = "nautilus" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = RetrieverServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "scallop", + "location": "abalone", + } + path = RetrieverServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = RetrieverServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.RetrieverServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.RetrieverServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = RetrieverServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = RetrieverServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = RetrieverServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (RetrieverServiceClient, transports.RetrieverServiceGrpcTransport), + (RetrieverServiceAsyncClient, transports.RetrieverServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_text_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_text_service.py new file mode 100644 index 000000000000..aeace4c2095f --- /dev/null +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1alpha/test_text_service.py @@ -0,0 +1,5101 @@ +# -*- coding: utf-8 -*- +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os + +# try/except added for compatibility with python < 3.8 +try: + from unittest import mock + from unittest.mock import AsyncMock # pragma: NO COVER +except ImportError: # pragma: NO COVER + import mock + +from collections.abc import AsyncIterable, Iterable +import json +import math + +from google.api_core import api_core_version +from google.protobuf import json_format +import grpc +from grpc.experimental import aio +from proto.marshal.rules import wrappers +from proto.marshal.rules.dates import DurationRule, TimestampRule +import pytest +from requests import PreparedRequest, Request, Response +from requests.sessions import Session + +try: + from google.auth.aio import credentials as ga_credentials_async + + HAS_GOOGLE_AUTH_AIO = True +except ImportError: # pragma: NO COVER + HAS_GOOGLE_AUTH_AIO = False + +from google.api_core import gapic_v1, grpc_helpers, grpc_helpers_async, path_template +from google.api_core import client_options +from google.api_core import exceptions as core_exceptions +from google.api_core import retry as retries +import google.auth +from google.auth import credentials as ga_credentials +from google.auth.exceptions import MutualTLSChannelError +from google.longrunning import operations_pb2 # type: ignore +from google.oauth2 import service_account + +from google.ai.generativelanguage_v1alpha.services.text_service import ( + TextServiceAsyncClient, + TextServiceClient, + transports, +) +from google.ai.generativelanguage_v1alpha.types import safety, text_service + + +async def mock_async_gen(data, chunk_size=1): + for i in range(0, len(data)): # pragma: NO COVER + chunk = data[i : i + chunk_size] + yield chunk.encode("utf-8") + + +def client_cert_source_callback(): + return b"cert bytes", b"key bytes" + + +# TODO: use async auth anon credentials by default once the minimum version of google-auth is upgraded. +# See related issue: https://github.com/googleapis/gapic-generator-python/issues/2107. +def async_anonymous_credentials(): + if HAS_GOOGLE_AUTH_AIO: + return ga_credentials_async.AnonymousCredentials() + return ga_credentials.AnonymousCredentials() + + +# If default endpoint is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint(client): + return ( + "foo.googleapis.com" + if ("localhost" in client.DEFAULT_ENDPOINT) + else client.DEFAULT_ENDPOINT + ) + + +# If default endpoint template is localhost, then default mtls endpoint will be the same. +# This method modifies the default endpoint template so the client can produce a different +# mtls endpoint for endpoint testing purposes. +def modify_default_endpoint_template(client): + return ( + "test.{UNIVERSE_DOMAIN}" + if ("localhost" in client._DEFAULT_ENDPOINT_TEMPLATE) + else client._DEFAULT_ENDPOINT_TEMPLATE + ) + + +def test__get_default_mtls_endpoint(): + api_endpoint = "example.googleapis.com" + api_mtls_endpoint = "example.mtls.googleapis.com" + sandbox_endpoint = "example.sandbox.googleapis.com" + sandbox_mtls_endpoint = "example.mtls.sandbox.googleapis.com" + non_googleapi = "api.example.com" + + assert TextServiceClient._get_default_mtls_endpoint(None) is None + assert ( + TextServiceClient._get_default_mtls_endpoint(api_endpoint) == api_mtls_endpoint + ) + assert ( + TextServiceClient._get_default_mtls_endpoint(api_mtls_endpoint) + == api_mtls_endpoint + ) + assert ( + TextServiceClient._get_default_mtls_endpoint(sandbox_endpoint) + == sandbox_mtls_endpoint + ) + assert ( + TextServiceClient._get_default_mtls_endpoint(sandbox_mtls_endpoint) + == sandbox_mtls_endpoint + ) + assert TextServiceClient._get_default_mtls_endpoint(non_googleapi) == non_googleapi + + +def test__read_environment_variables(): + assert TextServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + assert TextServiceClient._read_environment_variables() == (True, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + assert TextServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + TextServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + assert TextServiceClient._read_environment_variables() == (False, "never", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + assert TextServiceClient._read_environment_variables() == ( + False, + "always", + None, + ) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}): + assert TextServiceClient._read_environment_variables() == (False, "auto", None) + + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + TextServiceClient._read_environment_variables() + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + with mock.patch.dict(os.environ, {"GOOGLE_CLOUD_UNIVERSE_DOMAIN": "foo.com"}): + assert TextServiceClient._read_environment_variables() == ( + False, + "auto", + "foo.com", + ) + + +def test__get_client_cert_source(): + mock_provided_cert_source = mock.Mock() + mock_default_cert_source = mock.Mock() + + assert TextServiceClient._get_client_cert_source(None, False) is None + assert ( + TextServiceClient._get_client_cert_source(mock_provided_cert_source, False) + is None + ) + assert ( + TextServiceClient._get_client_cert_source(mock_provided_cert_source, True) + == mock_provided_cert_source + ) + + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", return_value=True + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_default_cert_source, + ): + assert ( + TextServiceClient._get_client_cert_source(None, True) + is mock_default_cert_source + ) + assert ( + TextServiceClient._get_client_cert_source( + mock_provided_cert_source, "true" + ) + is mock_provided_cert_source + ) + + +@mock.patch.object( + TextServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceClient), +) +@mock.patch.object( + TextServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceAsyncClient), +) +def test__get_api_endpoint(): + api_override = "foo.com" + mock_client_cert_source = mock.Mock() + default_universe = TextServiceClient._DEFAULT_UNIVERSE + default_endpoint = TextServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = TextServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + assert ( + TextServiceClient._get_api_endpoint( + api_override, mock_client_cert_source, default_universe, "always" + ) + == api_override + ) + assert ( + TextServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "auto" + ) + == TextServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + TextServiceClient._get_api_endpoint(None, None, default_universe, "auto") + == default_endpoint + ) + assert ( + TextServiceClient._get_api_endpoint(None, None, default_universe, "always") + == TextServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + TextServiceClient._get_api_endpoint( + None, mock_client_cert_source, default_universe, "always" + ) + == TextServiceClient.DEFAULT_MTLS_ENDPOINT + ) + assert ( + TextServiceClient._get_api_endpoint(None, None, mock_universe, "never") + == mock_endpoint + ) + assert ( + TextServiceClient._get_api_endpoint(None, None, default_universe, "never") + == default_endpoint + ) + + with pytest.raises(MutualTLSChannelError) as excinfo: + TextServiceClient._get_api_endpoint( + None, mock_client_cert_source, mock_universe, "auto" + ) + assert ( + str(excinfo.value) + == "mTLS is not supported in any universe other than googleapis.com." + ) + + +def test__get_universe_domain(): + client_universe_domain = "foo.com" + universe_domain_env = "bar.com" + + assert ( + TextServiceClient._get_universe_domain( + client_universe_domain, universe_domain_env + ) + == client_universe_domain + ) + assert ( + TextServiceClient._get_universe_domain(None, universe_domain_env) + == universe_domain_env + ) + assert ( + TextServiceClient._get_universe_domain(None, None) + == TextServiceClient._DEFAULT_UNIVERSE + ) + + with pytest.raises(ValueError) as excinfo: + TextServiceClient._get_universe_domain("", None) + assert str(excinfo.value) == "Universe Domain cannot be an empty string." + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (TextServiceClient, "grpc"), + (TextServiceAsyncClient, "grpc_asyncio"), + (TextServiceClient, "rest"), + ], +) +def test_text_service_client_from_service_account_info(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_info" + ) as factory: + factory.return_value = creds + info = {"valid": True} + client = client_class.from_service_account_info(info, transport=transport_name) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_class,transport_name", + [ + (transports.TextServiceGrpcTransport, "grpc"), + (transports.TextServiceGrpcAsyncIOTransport, "grpc_asyncio"), + (transports.TextServiceRestTransport, "rest"), + ], +) +def test_text_service_client_service_account_always_use_jwt( + transport_class, transport_name +): + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=True) + use_jwt.assert_called_once_with(True) + + with mock.patch.object( + service_account.Credentials, "with_always_use_jwt_access", create=True + ) as use_jwt: + creds = service_account.Credentials(None, None, None) + transport = transport_class(credentials=creds, always_use_jwt_access=False) + use_jwt.assert_not_called() + + +@pytest.mark.parametrize( + "client_class,transport_name", + [ + (TextServiceClient, "grpc"), + (TextServiceAsyncClient, "grpc_asyncio"), + (TextServiceClient, "rest"), + ], +) +def test_text_service_client_from_service_account_file(client_class, transport_name): + creds = ga_credentials.AnonymousCredentials() + with mock.patch.object( + service_account.Credentials, "from_service_account_file" + ) as factory: + factory.return_value = creds + client = client_class.from_service_account_file( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + client = client_class.from_service_account_json( + "dummy/file/path.json", transport=transport_name + ) + assert client.transport._credentials == creds + assert isinstance(client, client_class) + + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +def test_text_service_client_get_transport_class(): + transport = TextServiceClient.get_transport_class() + available_transports = [ + transports.TextServiceGrpcTransport, + transports.TextServiceRestTransport, + ] + assert transport in available_transports + + transport = TextServiceClient.get_transport_class("grpc") + assert transport == transports.TextServiceGrpcTransport + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest"), + ], +) +@mock.patch.object( + TextServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceClient), +) +@mock.patch.object( + TextServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceAsyncClient), +) +def test_text_service_client_client_options( + client_class, transport_class, transport_name +): + # Check that if channel is provided we won't create a new one. + with mock.patch.object(TextServiceClient, "get_transport_class") as gtc: + transport = transport_class(credentials=ga_credentials.AnonymousCredentials()) + client = client_class(transport=transport) + gtc.assert_not_called() + + # Check that if channel is provided via str we will create a new one. + with mock.patch.object(TextServiceClient, "get_transport_class") as gtc: + client = client_class(transport=transport_name) + gtc.assert_called() + + # Check the case api_endpoint is provided. + options = client_options.ClientOptions(api_endpoint="squid.clam.whelk") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name, client_options=options) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT is + # "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client.DEFAULT_MTLS_ENDPOINT, + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client = client_class(transport=transport_name) + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + # Check the case quota_project_id is provided + options = client_options.ClientOptions(quota_project_id="octopus") + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id="octopus", + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + # Check the case api_endpoint is provided + options = client_options.ClientOptions( + api_audience="https://language.googleapis.com" + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience="https://language.googleapis.com", + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,use_client_cert_env", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "true"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "true", + ), + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", "false"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + "false", + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest", "true"), + (TextServiceClient, transports.TextServiceRestTransport, "rest", "false"), + ], +) +@mock.patch.object( + TextServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceClient), +) +@mock.patch.object( + TextServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceAsyncClient), +) +@mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) +def test_text_service_client_mtls_env_auto( + client_class, transport_class, transport_name, use_client_cert_env +): + # This tests the endpoint autoswitch behavior. Endpoint is autoswitched to the default + # mtls endpoint, if GOOGLE_API_USE_CLIENT_CERTIFICATE is "true" and client cert exists. + + # Check the case client_cert_source is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + options = client_options.ClientOptions( + client_cert_source=client_cert_source_callback + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + + if use_client_cert_env == "false": + expected_client_cert_source = None + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + else: + expected_client_cert_source = client_cert_source_callback + expected_host = client.DEFAULT_MTLS_ENDPOINT + + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case ADC client cert is provided. Whether client cert is used depends on + # GOOGLE_API_USE_CLIENT_CERTIFICATE value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=client_cert_source_callback, + ): + if use_client_cert_env == "false": + expected_host = client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ) + expected_client_cert_source = None + else: + expected_host = client.DEFAULT_MTLS_ENDPOINT + expected_client_cert_source = client_cert_source_callback + + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=expected_host, + scopes=None, + client_cert_source_for_mtls=expected_client_cert_source, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # Check the case client_cert_source and ADC client cert are not provided. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": use_client_cert_env} + ): + with mock.patch.object(transport_class, "__init__") as patched: + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + patched.return_value = None + client = client_class(transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize("client_class", [TextServiceClient, TextServiceAsyncClient]) +@mock.patch.object( + TextServiceClient, "DEFAULT_ENDPOINT", modify_default_endpoint(TextServiceClient) +) +@mock.patch.object( + TextServiceAsyncClient, + "DEFAULT_ENDPOINT", + modify_default_endpoint(TextServiceAsyncClient), +) +def test_text_service_client_get_mtls_endpoint_and_cert_source(client_class): + mock_client_cert_source = mock.Mock() + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "true". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source == mock_client_cert_source + + # Test the case GOOGLE_API_USE_CLIENT_CERTIFICATE is "false". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}): + mock_client_cert_source = mock.Mock() + mock_api_endpoint = "foo" + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=mock_api_endpoint + ) + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source( + options + ) + assert api_endpoint == mock_api_endpoint + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "never". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "always". + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert doesn't exist. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=False, + ): + api_endpoint, cert_source = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_ENDPOINT + assert cert_source is None + + # Test the case GOOGLE_API_USE_MTLS_ENDPOINT is "auto" and default cert exists. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.mtls.has_default_client_cert_source", + return_value=True, + ): + with mock.patch( + "google.auth.transport.mtls.default_client_cert_source", + return_value=mock_client_cert_source, + ): + ( + api_endpoint, + cert_source, + ) = client_class.get_mtls_endpoint_and_cert_source() + assert api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + assert cert_source == mock_client_cert_source + + # Check the case api_endpoint is not provided and GOOGLE_API_USE_MTLS_ENDPOINT has + # unsupported value. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "Unsupported"}): + with pytest.raises(MutualTLSChannelError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_MTLS_ENDPOINT` must be `never`, `auto` or `always`" + ) + + # Check the case GOOGLE_API_USE_CLIENT_CERTIFICATE has unsupported value. + with mock.patch.dict( + os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "Unsupported"} + ): + with pytest.raises(ValueError) as excinfo: + client_class.get_mtls_endpoint_and_cert_source() + + assert ( + str(excinfo.value) + == "Environment variable `GOOGLE_API_USE_CLIENT_CERTIFICATE` must be either `true` or `false`" + ) + + +@pytest.mark.parametrize("client_class", [TextServiceClient, TextServiceAsyncClient]) +@mock.patch.object( + TextServiceClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceClient), +) +@mock.patch.object( + TextServiceAsyncClient, + "_DEFAULT_ENDPOINT_TEMPLATE", + modify_default_endpoint_template(TextServiceAsyncClient), +) +def test_text_service_client_client_api_endpoint(client_class): + mock_client_cert_source = client_cert_source_callback + api_override = "foo.com" + default_universe = TextServiceClient._DEFAULT_UNIVERSE + default_endpoint = TextServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=default_universe + ) + mock_universe = "bar.com" + mock_endpoint = TextServiceClient._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=mock_universe + ) + + # If ClientOptions.api_endpoint is set and GOOGLE_API_USE_CLIENT_CERTIFICATE="true", + # use ClientOptions.api_endpoint as the api endpoint regardless. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ): + options = client_options.ClientOptions( + client_cert_source=mock_client_cert_source, api_endpoint=api_override + ) + client = client_class( + client_options=options, + credentials=ga_credentials.AnonymousCredentials(), + ) + assert client.api_endpoint == api_override + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == default_endpoint + + # If ClientOptions.api_endpoint is not set and GOOGLE_API_USE_MTLS_ENDPOINT="always", + # use the DEFAULT_MTLS_ENDPOINT as the api endpoint. + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}): + client = client_class(credentials=ga_credentials.AnonymousCredentials()) + assert client.api_endpoint == client_class.DEFAULT_MTLS_ENDPOINT + + # If ClientOptions.api_endpoint is not set, GOOGLE_API_USE_MTLS_ENDPOINT="auto" (default), + # GOOGLE_API_USE_CLIENT_CERTIFICATE="false" (default), default cert source doesn't exist, + # and ClientOptions.universe_domain="bar.com", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with universe domain as the api endpoint. + options = client_options.ClientOptions() + universe_exists = hasattr(options, "universe_domain") + if universe_exists: + options = client_options.ClientOptions(universe_domain=mock_universe) + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + else: + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == ( + mock_endpoint if universe_exists else default_endpoint + ) + assert client.universe_domain == ( + mock_universe if universe_exists else default_universe + ) + + # If ClientOptions does not have a universe domain attribute and GOOGLE_API_USE_MTLS_ENDPOINT="never", + # use the _DEFAULT_ENDPOINT_TEMPLATE populated with GDU as the api endpoint. + options = client_options.ClientOptions() + if hasattr(options, "universe_domain"): + delattr(options, "universe_domain") + with mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}): + client = client_class( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + assert client.api_endpoint == default_endpoint + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc"), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest"), + ], +) +def test_text_service_client_client_options_scopes( + client_class, transport_class, transport_name +): + # Check the case scopes are provided. + options = client_options.ClientOptions( + scopes=["1", "2"], + ) + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=["1", "2"], + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + (TextServiceClient, transports.TextServiceRestTransport, "rest", None), + ], +) +def test_text_service_client_client_options_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +def test_text_service_client_client_options_from_dict(): + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.text_service.transports.TextServiceGrpcTransport.__init__" + ) as grpc_transport: + grpc_transport.return_value = None + client = TextServiceClient(client_options={"api_endpoint": "squid.clam.whelk"}) + grpc_transport.assert_called_once_with( + credentials=None, + credentials_file=None, + host="squid.clam.whelk", + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + +@pytest.mark.parametrize( + "client_class,transport_class,transport_name,grpc_helpers", + [ + (TextServiceClient, transports.TextServiceGrpcTransport, "grpc", grpc_helpers), + ( + TextServiceAsyncClient, + transports.TextServiceGrpcAsyncIOTransport, + "grpc_asyncio", + grpc_helpers_async, + ), + ], +) +def test_text_service_client_create_channel_credentials_file( + client_class, transport_class, transport_name, grpc_helpers +): + # Check the case credentials file is provided. + options = client_options.ClientOptions(credentials_file="credentials.json") + + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options, transport=transport_name) + patched.assert_called_once_with( + credentials=None, + credentials_file="credentials.json", + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) + + # test that the credentials from file are saved and used as the credentials. + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel" + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + file_creds = ga_credentials.AnonymousCredentials() + load_creds.return_value = (file_creds, None) + adc.return_value = (creds, None) + client = client_class(client_options=options, transport=transport_name) + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=file_creds, + credentials_file=None, + quota_project_id=None, + default_scopes=(), + scopes=None, + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.GenerateTextRequest, + dict, + ], +) +def test_generate_text(request_type, transport: str = "grpc"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.GenerateTextResponse() + response = client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = text_service.GenerateTextRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +def test_generate_text_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = text_service.GenerateTextRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.generate_text(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == text_service.GenerateTextRequest( + model="model_value", + ) + + +def test_generate_text_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_text in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.generate_text] = mock_rpc + request = {} + client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_text_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.generate_text + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.generate_text + ] = mock_rpc + + request = {} + await client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.generate_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_generate_text_async( + transport: str = "grpc_asyncio", request_type=text_service.GenerateTextRequest +): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) + response = await client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = text_service.GenerateTextRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +@pytest.mark.asyncio +async def test_generate_text_async_from_dict(): + await test_generate_text_async(request_type=dict) + + +def test_generate_text_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.GenerateTextRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + call.return_value = text_service.GenerateTextResponse() + client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_generate_text_field_headers_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.GenerateTextRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) + await client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_generate_text_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.GenerateTextResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.generate_text( + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = text_service.TextPrompt(text="text_value") + assert arg == mock_val + assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) + arg = args[0].candidate_count + mock_val = 1573 + assert arg == mock_val + arg = args[0].max_output_tokens + mock_val = 1865 + assert arg == mock_val + assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) + arg = args[0].top_k + mock_val = 541 + assert arg == mock_val + + +def test_generate_text_flattened_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_text( + text_service.GenerateTextRequest(), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + +@pytest.mark.asyncio +async def test_generate_text_flattened_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.GenerateTextResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.generate_text( + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = text_service.TextPrompt(text="text_value") + assert arg == mock_val + assert math.isclose(args[0].temperature, 0.1198, rel_tol=1e-6) + arg = args[0].candidate_count + mock_val = 1573 + assert arg == mock_val + arg = args[0].max_output_tokens + mock_val = 1865 + assert arg == mock_val + assert math.isclose(args[0].top_p, 0.546, rel_tol=1e-6) + arg = args[0].top_k + mock_val = 541 + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_generate_text_flattened_error_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.generate_text( + text_service.GenerateTextRequest(), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.EmbedTextRequest, + dict, + ], +) +def test_embed_text(request_type, transport: str = "grpc"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.EmbedTextResponse() + response = client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = text_service.EmbedTextRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.EmbedTextResponse) + + +def test_embed_text_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = text_service.EmbedTextRequest( + model="model_value", + text="text_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.embed_text(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == text_service.EmbedTextRequest( + model="model_value", + text="text_value", + ) + + +def test_embed_text_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.embed_text in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.embed_text] = mock_rpc + request = {} + client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.embed_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_embed_text_async_use_cached_wrapped_rpc(transport: str = "grpc_asyncio"): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.embed_text + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.embed_text + ] = mock_rpc + + request = {} + await client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.embed_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_embed_text_async( + transport: str = "grpc_asyncio", request_type=text_service.EmbedTextRequest +): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) + response = await client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = text_service.EmbedTextRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.EmbedTextResponse) + + +@pytest.mark.asyncio +async def test_embed_text_async_from_dict(): + await test_embed_text_async(request_type=dict) + + +def test_embed_text_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.EmbedTextRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + call.return_value = text_service.EmbedTextResponse() + client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_embed_text_field_headers_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.EmbedTextRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) + await client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_embed_text_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.EmbedTextResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.embed_text( + model="model_value", + text="text_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].text + mock_val = "text_value" + assert arg == mock_val + + +def test_embed_text_flattened_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.embed_text( + text_service.EmbedTextRequest(), + model="model_value", + text="text_value", + ) + + +@pytest.mark.asyncio +async def test_embed_text_flattened_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.EmbedTextResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.embed_text( + model="model_value", + text="text_value", + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].text + mock_val = "text_value" + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_embed_text_flattened_error_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.embed_text( + text_service.EmbedTextRequest(), + model="model_value", + text="text_value", + ) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.BatchEmbedTextRequest, + dict, + ], +) +def test_batch_embed_text(request_type, transport: str = "grpc"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.BatchEmbedTextResponse() + response = client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = text_service.BatchEmbedTextRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.BatchEmbedTextResponse) + + +def test_batch_embed_text_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = text_service.BatchEmbedTextRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.batch_embed_text(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == text_service.BatchEmbedTextRequest( + model="model_value", + ) + + +def test_batch_embed_text_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.batch_embed_text in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_embed_text + ] = mock_rpc + request = {} + client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_embed_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_embed_text_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.batch_embed_text + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.batch_embed_text + ] = mock_rpc + + request = {} + await client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.batch_embed_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_batch_embed_text_async( + transport: str = "grpc_asyncio", request_type=text_service.BatchEmbedTextRequest +): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) + response = await client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = text_service.BatchEmbedTextRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.BatchEmbedTextResponse) + + +@pytest.mark.asyncio +async def test_batch_embed_text_async_from_dict(): + await test_batch_embed_text_async(request_type=dict) + + +def test_batch_embed_text_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.BatchEmbedTextRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + call.return_value = text_service.BatchEmbedTextResponse() + client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_batch_embed_text_field_headers_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.BatchEmbedTextRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) + await client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_batch_embed_text_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.BatchEmbedTextResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.batch_embed_text( + model="model_value", + texts=["texts_value"], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].texts + mock_val = ["texts_value"] + assert arg == mock_val + + +def test_batch_embed_text_flattened_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_embed_text( + text_service.BatchEmbedTextRequest(), + model="model_value", + texts=["texts_value"], + ) + + +@pytest.mark.asyncio +async def test_batch_embed_text_flattened_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.BatchEmbedTextResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.batch_embed_text( + model="model_value", + texts=["texts_value"], + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].texts + mock_val = ["texts_value"] + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_batch_embed_text_flattened_error_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.batch_embed_text( + text_service.BatchEmbedTextRequest(), + model="model_value", + texts=["texts_value"], + ) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.CountTextTokensRequest, + dict, + ], +) +def test_count_text_tokens(request_type, transport: str = "grpc"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.CountTextTokensResponse( + token_count=1193, + ) + response = client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + request = text_service.CountTextTokensRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.CountTextTokensResponse) + assert response.token_count == 1193 + + +def test_count_text_tokens_non_empty_request_with_auto_populated_field(): + # This test is a coverage failsafe to make sure that UUID4 fields are + # automatically populated, according to AIP-4235, with non-empty requests. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Populate all string fields in the request which are not UUID4 + # since we want to check that UUID4 are populated automatically + # if they meet the requirements of AIP 4235. + request = text_service.CountTextTokensRequest( + model="model_value", + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + call.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client.count_text_tokens(request=request) + call.assert_called() + _, args, _ = call.mock_calls[0] + assert args[0] == text_service.CountTextTokensRequest( + model="model_value", + ) + + +def test_count_text_tokens_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_text_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.count_text_tokens + ] = mock_rpc + request = {} + client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_text_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_count_text_tokens_async_use_cached_wrapped_rpc( + transport: str = "grpc_asyncio", +): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method_async.wrap_method") as wrapper_fn: + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert ( + client._client._transport.count_text_tokens + in client._client._transport._wrapped_methods + ) + + # Replace cached wrapped function with mock + mock_rpc = mock.AsyncMock() + mock_rpc.return_value = mock.Mock() + client._client._transport._wrapped_methods[ + client._client._transport.count_text_tokens + ] = mock_rpc + + request = {} + await client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + await client.count_text_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +@pytest.mark.asyncio +async def test_count_text_tokens_async( + transport: str = "grpc_asyncio", request_type=text_service.CountTextTokensRequest +): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = request_type() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse( + token_count=1193, + ) + ) + response = await client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + request = text_service.CountTextTokensRequest() + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.CountTextTokensResponse) + assert response.token_count == 1193 + + +@pytest.mark.asyncio +async def test_count_text_tokens_async_from_dict(): + await test_count_text_tokens_async(request_type=dict) + + +def test_count_text_tokens_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.CountTextTokensRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + call.return_value = text_service.CountTextTokensResponse() + client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_count_text_tokens_field_headers_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = text_service.CountTextTokensRequest() + + request.model = "model_value" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse() + ) + await client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "model=model_value", + ) in kw["metadata"] + + +def test_count_text_tokens_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.CountTextTokensResponse() + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + client.count_text_tokens( + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = text_service.TextPrompt(text="text_value") + assert arg == mock_val + + +def test_count_text_tokens_flattened_error(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_text_tokens( + text_service.CountTextTokensRequest(), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + ) + + +@pytest.mark.asyncio +async def test_count_text_tokens_flattened_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = text_service.CountTextTokensResponse() + + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse() + ) + # Call the method with a truthy value for each flattened field, + # using the keyword arguments to the method. + response = await client.count_text_tokens( + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + ) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(call.mock_calls) + _, args, _ = call.mock_calls[0] + arg = args[0].model + mock_val = "model_value" + assert arg == mock_val + arg = args[0].prompt + mock_val = text_service.TextPrompt(text="text_value") + assert arg == mock_val + + +@pytest.mark.asyncio +async def test_count_text_tokens_flattened_error_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + await client.count_text_tokens( + text_service.CountTextTokensRequest(), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + ) + + +def test_generate_text_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.generate_text in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.generate_text] = mock_rpc + + request = {} + client.generate_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.generate_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_generate_text_rest_required_fields( + request_type=text_service.GenerateTextRequest, +): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).generate_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.generate_text(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_generate_text_rest_unset_required_fields(): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.generate_text._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) + + +def test_generate_text_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.generate_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:generateText" % client.transport._host, args[1] + ) + + +def test_generate_text_rest_flattened_error(transport: str = "rest"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.generate_text( + text_service.GenerateTextRequest(), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + temperature=0.1198, + candidate_count=1573, + max_output_tokens=1865, + top_p=0.546, + top_k=541, + ) + + +def test_embed_text_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.embed_text in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[client._transport.embed_text] = mock_rpc + + request = {} + client.embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.embed_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_embed_text_rest_required_fields(request_type=text_service.EmbedTextRequest): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.embed_text(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_embed_text_rest_unset_required_fields(): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.embed_text._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("model",))) + + +def test_embed_text_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + text="text_value", + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.embed_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:embedText" % client.transport._host, args[1] + ) + + +def test_embed_text_rest_flattened_error(transport: str = "rest"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.embed_text( + text_service.EmbedTextRequest(), + model="model_value", + text="text_value", + ) + + +def test_batch_embed_text_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.batch_embed_text in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.batch_embed_text + ] = mock_rpc + + request = {} + client.batch_embed_text(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.batch_embed_text(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_batch_embed_text_rest_required_fields( + request_type=text_service.BatchEmbedTextRequest, +): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).batch_embed_text._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.BatchEmbedTextResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.BatchEmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.batch_embed_text(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_batch_embed_text_rest_unset_required_fields(): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.batch_embed_text._get_unset_required_fields({}) + assert set(unset_fields) == (set(()) & set(("model",))) + + +def test_batch_embed_text_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.BatchEmbedTextResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + texts=["texts_value"], + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = text_service.BatchEmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.batch_embed_text(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:batchEmbedText" % client.transport._host, + args[1], + ) + + +def test_batch_embed_text_rest_flattened_error(transport: str = "rest"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.batch_embed_text( + text_service.BatchEmbedTextRequest(), + model="model_value", + texts=["texts_value"], + ) + + +def test_count_text_tokens_rest_use_cached_wrapped_rpc(): + # Clients should use _prep_wrapped_messages to create cached wrapped rpcs, + # instead of constructing them on each call + with mock.patch("google.api_core.gapic_v1.method.wrap_method") as wrapper_fn: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Should wrap all calls on client creation + assert wrapper_fn.call_count > 0 + wrapper_fn.reset_mock() + + # Ensure method has been cached + assert client._transport.count_text_tokens in client._transport._wrapped_methods + + # Replace cached wrapped function with mock + mock_rpc = mock.Mock() + mock_rpc.return_value.name = ( + "foo" # operation_request.operation in compute client(s) expect a string. + ) + client._transport._wrapped_methods[ + client._transport.count_text_tokens + ] = mock_rpc + + request = {} + client.count_text_tokens(request) + + # Establish that the underlying gRPC stub method was called. + assert mock_rpc.call_count == 1 + + client.count_text_tokens(request) + + # Establish that a new wrapper was not created for this call + assert wrapper_fn.call_count == 0 + assert mock_rpc.call_count == 2 + + +def test_count_text_tokens_rest_required_fields( + request_type=text_service.CountTextTokensRequest, +): + transport_class = transports.TextServiceRestTransport + + request_init = {} + request_init["model"] = "" + request = request_type(**request_init) + pb_request = request_type.pb(request) + jsonified_request = json.loads( + json_format.MessageToJson(pb_request, use_integers_for_enums=False) + ) + + # verify fields with default values are dropped + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_text_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with default values are now present + + jsonified_request["model"] = "model_value" + + unset_fields = transport_class( + credentials=ga_credentials.AnonymousCredentials() + ).count_text_tokens._get_unset_required_fields(jsonified_request) + jsonified_request.update(unset_fields) + + # verify required fields with non-default values are left alone + assert "model" in jsonified_request + assert jsonified_request["model"] == "model_value" + + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type(**request_init) + + # Designate an appropriate value for the returned response. + return_value = text_service.CountTextTokensResponse() + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # We need to mock transcode() because providing default values + # for required fields will fail the real version if the http_options + # expect actual values for those fields. + with mock.patch.object(path_template, "transcode") as transcode: + # A uri without fields and an empty body will force all the + # request fields to show up in the query_params. + pb_request = request_type.pb(request) + transcode_result = { + "uri": "v1/sample_method", + "method": "post", + "query_params": pb_request, + } + transcode_result["body"] = pb_request + transcode.return_value = transcode_result + + response_value = Response() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.CountTextTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.count_text_tokens(request) + + expected_params = [("$alt", "json;enum-encoding=int")] + actual_params = req.call_args.kwargs["params"] + assert expected_params == actual_params + + +def test_count_text_tokens_rest_unset_required_fields(): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials + ) + + unset_fields = transport.count_text_tokens._get_unset_required_fields({}) + assert set(unset_fields) == ( + set(()) + & set( + ( + "model", + "prompt", + ) + ) + ) + + +def test_count_text_tokens_rest_flattened(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.CountTextTokensResponse() + + # get arguments that satisfy an http rule for this method + sample_request = {"model": "models/sample1"} + + # get truthy value for each flattened field + mock_args = dict( + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + ) + mock_args.update(sample_request) + + # Wrap the value into a proper Response obj + response_value = Response() + response_value.status_code = 200 + # Convert return value to protobuf type + return_value = text_service.CountTextTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value._content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + client.count_text_tokens(**mock_args) + + # Establish that the underlying call was made with the expected + # request object values. + assert len(req.mock_calls) == 1 + _, args, _ = req.mock_calls[0] + assert path_template.validate( + "%s/v1alpha/{model=models/*}:countTextTokens" % client.transport._host, + args[1], + ) + + +def test_count_text_tokens_rest_flattened_error(transport: str = "rest"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Attempting to call a method with both a request object and flattened + # fields is an error. + with pytest.raises(ValueError): + client.count_text_tokens( + text_service.CountTextTokensRequest(), + model="model_value", + prompt=text_service.TextPrompt(text="text_value"), + ) + + +def test_credentials_transport_error(): + # It is an error to provide credentials and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # It is an error to provide a credentials file and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + client_options={"credentials_file": "credentials.json"}, + transport=transport, + ) + + # It is an error to provide an api_key and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TextServiceClient( + client_options=options, + transport=transport, + ) + + # It is an error to provide an api_key and a credential. + options = client_options.ClientOptions() + options.api_key = "api_key" + with pytest.raises(ValueError): + client = TextServiceClient( + client_options=options, credentials=ga_credentials.AnonymousCredentials() + ) + + # It is an error to provide scopes and a transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + with pytest.raises(ValueError): + client = TextServiceClient( + client_options={"scopes": ["1", "2"]}, + transport=transport, + ) + + +def test_transport_instance(): + # A client may be instantiated with a custom transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + client = TextServiceClient(transport=transport) + assert client.transport is transport + + +def test_transport_get_channel(): + # A client may be instantiated with a custom transport instance. + transport = transports.TextServiceGrpcTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + transport = transports.TextServiceGrpcAsyncIOTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + channel = transport.grpc_channel + assert channel + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + transports.TextServiceRestTransport, + ], +) +def test_transport_adc(transport_class): + # Test default credentials are used if not provided. + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class() + adc.assert_called_once() + + +def test_transport_kind_grpc(): + transport = TextServiceClient.get_transport_class("grpc")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "grpc" + + +def test_initialize_client_w_grpc(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_text_empty_call_grpc(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + call.return_value = text_service.GenerateTextResponse() + client.generate_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.GenerateTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_embed_text_empty_call_grpc(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + call.return_value = text_service.EmbedTextResponse() + client.embed_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.EmbedTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_embed_text_empty_call_grpc(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + call.return_value = text_service.BatchEmbedTextResponse() + client.batch_embed_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.BatchEmbedTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_count_text_tokens_empty_call_grpc(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="grpc", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + call.return_value = text_service.CountTextTokensResponse() + client.count_text_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.CountTextTokensRequest() + + assert args[0] == request_msg + + +def test_transport_kind_grpc_asyncio(): + transport = TextServiceAsyncClient.get_transport_class("grpc_asyncio")( + credentials=async_anonymous_credentials() + ) + assert transport.kind == "grpc_asyncio" + + +def test_initialize_client_w_grpc_asyncio(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_generate_text_empty_call_grpc_asyncio(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.GenerateTextResponse() + ) + await client.generate_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.GenerateTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_embed_text_empty_call_grpc_asyncio(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.EmbedTextResponse() + ) + await client.embed_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.EmbedTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_batch_embed_text_empty_call_grpc_asyncio(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.BatchEmbedTextResponse() + ) + await client.batch_embed_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.BatchEmbedTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +@pytest.mark.asyncio +async def test_count_text_tokens_empty_call_grpc_asyncio(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport="grpc_asyncio", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + text_service.CountTextTokensResponse( + token_count=1193, + ) + ) + await client.count_text_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.CountTextTokensRequest() + + assert args[0] == request_msg + + +def test_transport_kind_rest(): + transport = TextServiceClient.get_transport_class("rest")( + credentials=ga_credentials.AnonymousCredentials() + ) + assert transport.kind == "rest" + + +def test_generate_text_rest_bad_request(request_type=text_service.GenerateTextRequest): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.generate_text(request) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.GenerateTextRequest, + dict, + ], +) +def test_generate_text_rest_call_success(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.GenerateTextResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.GenerateTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.generate_text(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.GenerateTextResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_generate_text_rest_interceptors(null_interceptor): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) + client = TextServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_generate_text" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_generate_text" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = text_service.GenerateTextRequest.pb( + text_service.GenerateTextRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = text_service.GenerateTextResponse.to_json( + text_service.GenerateTextResponse() + ) + req.return_value.content = return_value + + request = text_service.GenerateTextRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = text_service.GenerateTextResponse() + + client.generate_text( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_embed_text_rest_bad_request(request_type=text_service.EmbedTextRequest): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.embed_text(request) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.EmbedTextRequest, + dict, + ], +) +def test_embed_text_rest_call_success(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.EmbedTextResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.EmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.embed_text(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.EmbedTextResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_embed_text_rest_interceptors(null_interceptor): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) + client = TextServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_embed_text" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_embed_text" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = text_service.EmbedTextRequest.pb(text_service.EmbedTextRequest()) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = text_service.EmbedTextResponse.to_json( + text_service.EmbedTextResponse() + ) + req.return_value.content = return_value + + request = text_service.EmbedTextRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = text_service.EmbedTextResponse() + + client.embed_text( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_batch_embed_text_rest_bad_request( + request_type=text_service.BatchEmbedTextRequest, +): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.batch_embed_text(request) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.BatchEmbedTextRequest, + dict, + ], +) +def test_batch_embed_text_rest_call_success(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.BatchEmbedTextResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.BatchEmbedTextResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.batch_embed_text(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.BatchEmbedTextResponse) + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_batch_embed_text_rest_interceptors(null_interceptor): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) + client = TextServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_batch_embed_text" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_batch_embed_text" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = text_service.BatchEmbedTextRequest.pb( + text_service.BatchEmbedTextRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = text_service.BatchEmbedTextResponse.to_json( + text_service.BatchEmbedTextResponse() + ) + req.return_value.content = return_value + + request = text_service.BatchEmbedTextRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = text_service.BatchEmbedTextResponse() + + client.batch_embed_text( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_count_text_tokens_rest_bad_request( + request_type=text_service.CountTextTokensRequest, +): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = mock.Mock() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = mock.Mock() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.count_text_tokens(request) + + +@pytest.mark.parametrize( + "request_type", + [ + text_service.CountTextTokensRequest, + dict, + ], +) +def test_count_text_tokens_rest_call_success(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + + # send a request that will satisfy transcoding + request_init = {"model": "models/sample1"} + request = request_type(**request_init) + + # Mock the http request call within the method and fake a response. + with mock.patch.object(type(client.transport._session), "request") as req: + # Designate an appropriate value for the returned response. + return_value = text_service.CountTextTokensResponse( + token_count=1193, + ) + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + + # Convert return value to protobuf type + return_value = text_service.CountTextTokensResponse.pb(return_value) + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + response = client.count_text_tokens(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, text_service.CountTextTokensResponse) + assert response.token_count == 1193 + + +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_count_text_tokens_rest_interceptors(null_interceptor): + transport = transports.TextServiceRestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None + if null_interceptor + else transports.TextServiceRestInterceptor(), + ) + client = TextServiceClient(transport=transport) + + with mock.patch.object( + type(client.transport._session), "request" + ) as req, mock.patch.object( + path_template, "transcode" + ) as transcode, mock.patch.object( + transports.TextServiceRestInterceptor, "post_count_text_tokens" + ) as post, mock.patch.object( + transports.TextServiceRestInterceptor, "pre_count_text_tokens" + ) as pre: + pre.assert_not_called() + post.assert_not_called() + pb_message = text_service.CountTextTokensRequest.pb( + text_service.CountTextTokensRequest() + ) + transcode.return_value = { + "method": "post", + "uri": "my_uri", + "body": pb_message, + "query_params": pb_message, + } + + req.return_value = mock.Mock() + req.return_value.status_code = 200 + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + return_value = text_service.CountTextTokensResponse.to_json( + text_service.CountTextTokensResponse() + ) + req.return_value.content = return_value + + request = text_service.CountTextTokensRequest() + metadata = [ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + post.return_value = text_service.CountTextTokensResponse() + + client.count_text_tokens( + request, + metadata=[ + ("key", "val"), + ("cephalopod", "squid"), + ], + ) + + pre.assert_called_once() + post.assert_called_once() + + +def test_get_operation_rest_bad_request( + request_type=operations_pb2.GetOperationRequest, +): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict( + {"name": "tunedModels/sample1/operations/sample2"}, request + ) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.get_operation(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.GetOperationRequest, + dict, + ], +) +def test_get_operation_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1/operations/sample2"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.Operation() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.get_operation(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_list_operations_rest_bad_request( + request_type=operations_pb2.ListOperationsRequest, +): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + request = request_type() + request = json_format.ParseDict({"name": "tunedModels/sample1"}, request) + + # Mock the http request call within the method and fake a BadRequest error. + with mock.patch.object(Session, "request") as req, pytest.raises( + core_exceptions.BadRequest + ): + # Wrap the value into a proper Response obj + response_value = Response() + json_return_value = "" + response_value.json = mock.Mock(return_value={}) + response_value.status_code = 400 + response_value.request = Request() + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + client.list_operations(request) + + +@pytest.mark.parametrize( + "request_type", + [ + operations_pb2.ListOperationsRequest, + dict, + ], +) +def test_list_operations_rest(request_type): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + request_init = {"name": "tunedModels/sample1"} + request = request_type(**request_init) + # Mock the http request call within the method and fake a response. + with mock.patch.object(Session, "request") as req: + # Designate an appropriate value for the returned response. + return_value = operations_pb2.ListOperationsResponse() + + # Wrap the value into a proper Response obj + response_value = mock.Mock() + response_value.status_code = 200 + json_return_value = json_format.MessageToJson(return_value) + response_value.content = json_return_value.encode("UTF-8") + + req.return_value = response_value + req.return_value.headers = {"header-1": "value-1", "header-2": "value-2"} + + response = client.list_operations(request) + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_initialize_client_w_rest(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + assert client is not None + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_generate_text_empty_call_rest(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.generate_text), "__call__") as call: + client.generate_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.GenerateTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_embed_text_empty_call_rest(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.embed_text), "__call__") as call: + client.embed_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.EmbedTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_batch_embed_text_empty_call_rest(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object(type(client.transport.batch_embed_text), "__call__") as call: + client.batch_embed_text(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.BatchEmbedTextRequest() + + assert args[0] == request_msg + + +# This test is a coverage failsafe to make sure that totally empty calls, +# i.e. request == None and no flattened fields passed, work. +def test_count_text_tokens_empty_call_rest(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport="rest", + ) + + # Mock the actual call, and fake the request. + with mock.patch.object( + type(client.transport.count_text_tokens), "__call__" + ) as call: + client.count_text_tokens(request=None) + + # Establish that the underlying stub method was called. + call.assert_called() + _, args, _ = call.mock_calls[0] + request_msg = text_service.CountTextTokensRequest() + + assert args[0] == request_msg + + +def test_transport_grpc_default(): + # A client should use the gRPC transport by default. + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + assert isinstance( + client.transport, + transports.TextServiceGrpcTransport, + ) + + +def test_text_service_base_transport_error(): + # Passing both a credentials object and credentials_file should raise an error + with pytest.raises(core_exceptions.DuplicateCredentialArgs): + transport = transports.TextServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + credentials_file="credentials.json", + ) + + +def test_text_service_base_transport(): + # Instantiate the base transport. + with mock.patch( + "google.ai.generativelanguage_v1alpha.services.text_service.transports.TextServiceTransport.__init__" + ) as Transport: + Transport.return_value = None + transport = transports.TextServiceTransport( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Every method on the transport should just blindly + # raise NotImplementedError. + methods = ( + "generate_text", + "embed_text", + "batch_embed_text", + "count_text_tokens", + "get_operation", + "list_operations", + ) + for method in methods: + with pytest.raises(NotImplementedError): + getattr(transport, method)(request=object()) + + with pytest.raises(NotImplementedError): + transport.close() + + # Catch all for all remaining methods and properties + remainder = [ + "kind", + ] + for r in remainder: + with pytest.raises(NotImplementedError): + getattr(transport, r)() + + +def test_text_service_base_transport_with_credentials_file(): + # Instantiate the base transport with a credentials file + with mock.patch.object( + google.auth, "load_credentials_from_file", autospec=True + ) as load_creds, mock.patch( + "google.ai.generativelanguage_v1alpha.services.text_service.transports.TextServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + load_creds.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TextServiceTransport( + credentials_file="credentials.json", + quota_project_id="octopus", + ) + load_creds.assert_called_once_with( + "credentials.json", + scopes=None, + default_scopes=(), + quota_project_id="octopus", + ) + + +def test_text_service_base_transport_with_adc(): + # Test the default credentials are used if credentials and credentials_file are None. + with mock.patch.object(google.auth, "default", autospec=True) as adc, mock.patch( + "google.ai.generativelanguage_v1alpha.services.text_service.transports.TextServiceTransport._prep_wrapped_messages" + ) as Transport: + Transport.return_value = None + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport = transports.TextServiceTransport() + adc.assert_called_once() + + +def test_text_service_auth_adc(): + # If no credentials are provided, we should use ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + TextServiceClient() + adc.assert_called_once_with( + scopes=None, + default_scopes=(), + quota_project_id=None, + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + ], +) +def test_text_service_transport_auth_adc(transport_class): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object(google.auth, "default", autospec=True) as adc: + adc.return_value = (ga_credentials.AnonymousCredentials(), None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + adc.assert_called_once_with( + scopes=["1", "2"], + default_scopes=(), + quota_project_id="octopus", + ) + + +@pytest.mark.parametrize( + "transport_class", + [ + transports.TextServiceGrpcTransport, + transports.TextServiceGrpcAsyncIOTransport, + transports.TextServiceRestTransport, + ], +) +def test_text_service_transport_auth_gdch_credentials(transport_class): + host = "https://language.com" + api_audience_tests = [None, "https://language2.com"] + api_audience_expect = [host, "https://language2.com"] + for t, e in zip(api_audience_tests, api_audience_expect): + with mock.patch.object(google.auth, "default", autospec=True) as adc: + gdch_mock = mock.MagicMock() + type(gdch_mock).with_gdch_audience = mock.PropertyMock( + return_value=gdch_mock + ) + adc.return_value = (gdch_mock, None) + transport_class(host=host, api_audience=t) + gdch_mock.with_gdch_audience.assert_called_once_with(e) + + +@pytest.mark.parametrize( + "transport_class,grpc_helpers", + [ + (transports.TextServiceGrpcTransport, grpc_helpers), + (transports.TextServiceGrpcAsyncIOTransport, grpc_helpers_async), + ], +) +def test_text_service_transport_create_channel(transport_class, grpc_helpers): + # If credentials and host are not provided, the transport class should use + # ADC credentials. + with mock.patch.object( + google.auth, "default", autospec=True + ) as adc, mock.patch.object( + grpc_helpers, "create_channel", autospec=True + ) as create_channel: + creds = ga_credentials.AnonymousCredentials() + adc.return_value = (creds, None) + transport_class(quota_project_id="octopus", scopes=["1", "2"]) + + create_channel.assert_called_with( + "generativelanguage.googleapis.com:443", + credentials=creds, + credentials_file=None, + quota_project_id="octopus", + default_scopes=(), + scopes=["1", "2"], + default_host="generativelanguage.googleapis.com", + ssl_credentials=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + +@pytest.mark.parametrize( + "transport_class", + [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport], +) +def test_text_service_grpc_transport_client_cert_source_for_mtls(transport_class): + cred = ga_credentials.AnonymousCredentials() + + # Check ssl_channel_credentials is used if provided. + with mock.patch.object(transport_class, "create_channel") as mock_create_channel: + mock_ssl_channel_creds = mock.Mock() + transport_class( + host="squid.clam.whelk", + credentials=cred, + ssl_channel_credentials=mock_ssl_channel_creds, + ) + mock_create_channel.assert_called_once_with( + "squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_channel_creds, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + + # Check if ssl_channel_credentials is not provided, then client_cert_source_for_mtls + # is used. + with mock.patch.object(transport_class, "create_channel", return_value=mock.Mock()): + with mock.patch("grpc.ssl_channel_credentials") as mock_ssl_cred: + transport_class( + credentials=cred, + client_cert_source_for_mtls=client_cert_source_callback, + ) + expected_cert, expected_key = client_cert_source_callback() + mock_ssl_cred.assert_called_once_with( + certificate_chain=expected_cert, private_key=expected_key + ) + + +def test_text_service_http_transport_client_cert_source_for_mtls(): + cred = ga_credentials.AnonymousCredentials() + with mock.patch( + "google.auth.transport.requests.AuthorizedSession.configure_mtls_channel" + ) as mock_configure_mtls_channel: + transports.TextServiceRestTransport( + credentials=cred, client_cert_source_for_mtls=client_cert_source_callback + ) + mock_configure_mtls_channel.assert_called_once_with(client_cert_source_callback) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_text_service_host_no_port(transport_name): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:443" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "grpc", + "grpc_asyncio", + "rest", + ], +) +def test_text_service_host_with_port(transport_name): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_options=client_options.ClientOptions( + api_endpoint="generativelanguage.googleapis.com:8000" + ), + transport=transport_name, + ) + assert client.transport._host == ( + "generativelanguage.googleapis.com:8000" + if transport_name in ["grpc", "grpc_asyncio"] + else "https://generativelanguage.googleapis.com:8000" + ) + + +@pytest.mark.parametrize( + "transport_name", + [ + "rest", + ], +) +def test_text_service_client_transport_session_collision(transport_name): + creds1 = ga_credentials.AnonymousCredentials() + creds2 = ga_credentials.AnonymousCredentials() + client1 = TextServiceClient( + credentials=creds1, + transport=transport_name, + ) + client2 = TextServiceClient( + credentials=creds2, + transport=transport_name, + ) + session1 = client1.transport.generate_text._session + session2 = client2.transport.generate_text._session + assert session1 != session2 + session1 = client1.transport.embed_text._session + session2 = client2.transport.embed_text._session + assert session1 != session2 + session1 = client1.transport.batch_embed_text._session + session2 = client2.transport.batch_embed_text._session + assert session1 != session2 + session1 = client1.transport.count_text_tokens._session + session2 = client2.transport.count_text_tokens._session + assert session1 != session2 + + +def test_text_service_grpc_transport_channel(): + channel = grpc.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TextServiceGrpcTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +def test_text_service_grpc_asyncio_transport_channel(): + channel = aio.secure_channel("http://localhost/", grpc.local_channel_credentials()) + + # Check that channel is used if provided. + transport = transports.TextServiceGrpcAsyncIOTransport( + host="squid.clam.whelk", + channel=channel, + ) + assert transport.grpc_channel == channel + assert transport._host == "squid.clam.whelk:443" + assert transport._ssl_channel_credentials == None + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport], +) +def test_text_service_transport_channel_mtls_with_client_cert_source(transport_class): + with mock.patch( + "grpc.ssl_channel_credentials", autospec=True + ) as grpc_ssl_channel_cred: + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_ssl_cred = mock.Mock() + grpc_ssl_channel_cred.return_value = mock_ssl_cred + + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + + cred = ga_credentials.AnonymousCredentials() + with pytest.warns(DeprecationWarning): + with mock.patch.object(google.auth, "default") as adc: + adc.return_value = (cred, None) + transport = transport_class( + host="squid.clam.whelk", + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=client_cert_source_callback, + ) + adc.assert_called_once() + + grpc_ssl_channel_cred.assert_called_once_with( + certificate_chain=b"cert bytes", private_key=b"key bytes" + ) + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + assert transport._ssl_channel_credentials == mock_ssl_cred + + +# Remove this test when deprecated arguments (api_mtls_endpoint, client_cert_source) are +# removed from grpc/grpc_asyncio transport constructor. +@pytest.mark.parametrize( + "transport_class", + [transports.TextServiceGrpcTransport, transports.TextServiceGrpcAsyncIOTransport], +) +def test_text_service_transport_channel_mtls_with_adc(transport_class): + mock_ssl_cred = mock.Mock() + with mock.patch.multiple( + "google.auth.transport.grpc.SslCredentials", + __init__=mock.Mock(return_value=None), + ssl_credentials=mock.PropertyMock(return_value=mock_ssl_cred), + ): + with mock.patch.object( + transport_class, "create_channel" + ) as grpc_create_channel: + mock_grpc_channel = mock.Mock() + grpc_create_channel.return_value = mock_grpc_channel + mock_cred = mock.Mock() + + with pytest.warns(DeprecationWarning): + transport = transport_class( + host="squid.clam.whelk", + credentials=mock_cred, + api_mtls_endpoint="mtls.squid.clam.whelk", + client_cert_source=None, + ) + + grpc_create_channel.assert_called_once_with( + "mtls.squid.clam.whelk:443", + credentials=mock_cred, + credentials_file=None, + scopes=None, + ssl_credentials=mock_ssl_cred, + quota_project_id=None, + options=[ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ], + ) + assert transport.grpc_channel == mock_grpc_channel + + +def test_model_path(): + model = "squid" + expected = "models/{model}".format( + model=model, + ) + actual = TextServiceClient.model_path(model) + assert expected == actual + + +def test_parse_model_path(): + expected = { + "model": "clam", + } + path = TextServiceClient.model_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_model_path(path) + assert expected == actual + + +def test_common_billing_account_path(): + billing_account = "whelk" + expected = "billingAccounts/{billing_account}".format( + billing_account=billing_account, + ) + actual = TextServiceClient.common_billing_account_path(billing_account) + assert expected == actual + + +def test_parse_common_billing_account_path(): + expected = { + "billing_account": "octopus", + } + path = TextServiceClient.common_billing_account_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_billing_account_path(path) + assert expected == actual + + +def test_common_folder_path(): + folder = "oyster" + expected = "folders/{folder}".format( + folder=folder, + ) + actual = TextServiceClient.common_folder_path(folder) + assert expected == actual + + +def test_parse_common_folder_path(): + expected = { + "folder": "nudibranch", + } + path = TextServiceClient.common_folder_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_folder_path(path) + assert expected == actual + + +def test_common_organization_path(): + organization = "cuttlefish" + expected = "organizations/{organization}".format( + organization=organization, + ) + actual = TextServiceClient.common_organization_path(organization) + assert expected == actual + + +def test_parse_common_organization_path(): + expected = { + "organization": "mussel", + } + path = TextServiceClient.common_organization_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_organization_path(path) + assert expected == actual + + +def test_common_project_path(): + project = "winkle" + expected = "projects/{project}".format( + project=project, + ) + actual = TextServiceClient.common_project_path(project) + assert expected == actual + + +def test_parse_common_project_path(): + expected = { + "project": "nautilus", + } + path = TextServiceClient.common_project_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_project_path(path) + assert expected == actual + + +def test_common_location_path(): + project = "scallop" + location = "abalone" + expected = "projects/{project}/locations/{location}".format( + project=project, + location=location, + ) + actual = TextServiceClient.common_location_path(project, location) + assert expected == actual + + +def test_parse_common_location_path(): + expected = { + "project": "squid", + "location": "clam", + } + path = TextServiceClient.common_location_path(**expected) + + # Check that the path construction is reversible. + actual = TextServiceClient.parse_common_location_path(path) + assert expected == actual + + +def test_client_with_default_client_info(): + client_info = gapic_v1.client_info.ClientInfo() + + with mock.patch.object( + transports.TextServiceTransport, "_prep_wrapped_messages" + ) as prep: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + with mock.patch.object( + transports.TextServiceTransport, "_prep_wrapped_messages" + ) as prep: + transport_class = TextServiceClient.get_transport_class() + transport = transport_class( + credentials=ga_credentials.AnonymousCredentials(), + client_info=client_info, + ) + prep.assert_called_once_with(client_info) + + +def test_get_operation(transport: str = "grpc"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + response = client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +@pytest.mark.asyncio +async def test_get_operation_async(transport: str = "grpc_asyncio"): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.GetOperationRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.Operation) + + +def test_get_operation_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = operations_pb2.Operation() + + client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_get_operation_field_headers_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.GetOperationRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + await client.get_operation(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_get_operation_from_dict(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.Operation() + + response = client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_get_operation_from_dict_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.get_operation), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.Operation() + ) + response = await client.get_operation( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_list_operations(transport: str = "grpc"): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + response = client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +@pytest.mark.asyncio +async def test_list_operations_async(transport: str = "grpc_asyncio"): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + transport=transport, + ) + + # Everything is optional in proto3 as far as the runtime is concerned, + # and we are mocking out the actual API, so just send an empty request. + request = operations_pb2.ListOperationsRequest() + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the response is the type that we expect. + assert isinstance(response, operations_pb2.ListOperationsResponse) + + +def test_list_operations_field_headers(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = operations_pb2.ListOperationsResponse() + + client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +@pytest.mark.asyncio +async def test_list_operations_field_headers_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + + # Any value that is part of the HTTP/1.1 URI should be sent as + # a field header. Set these to a non-empty value. + request = operations_pb2.ListOperationsRequest() + request.name = "locations" + + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + await client.list_operations(request) + # Establish that the underlying gRPC stub method was called. + assert len(call.mock_calls) == 1 + _, args, _ = call.mock_calls[0] + assert args[0] == request + + # Establish that the field header was sent. + _, _, kw = call.mock_calls[0] + assert ( + "x-goog-request-params", + "name=locations", + ) in kw["metadata"] + + +def test_list_operations_from_dict(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = operations_pb2.ListOperationsResponse() + + response = client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +@pytest.mark.asyncio +async def test_list_operations_from_dict_async(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), + ) + # Mock the actual call within the gRPC stub, and fake the request. + with mock.patch.object(type(client.transport.list_operations), "__call__") as call: + # Designate an appropriate return value for the call. + call.return_value = grpc_helpers_async.FakeUnaryUnaryCall( + operations_pb2.ListOperationsResponse() + ) + response = await client.list_operations( + request={ + "name": "locations", + } + ) + call.assert_called() + + +def test_transport_close_grpc(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="grpc" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +@pytest.mark.asyncio +async def test_transport_close_grpc_asyncio(): + client = TextServiceAsyncClient( + credentials=async_anonymous_credentials(), transport="grpc_asyncio" + ) + with mock.patch.object( + type(getattr(client.transport, "_grpc_channel")), "close" + ) as close: + async with client: + close.assert_not_called() + close.assert_called_once() + + +def test_transport_close_rest(): + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport="rest" + ) + with mock.patch.object( + type(getattr(client.transport, "_session")), "close" + ) as close: + with client: + close.assert_not_called() + close.assert_called_once() + + +def test_client_ctx(): + transports = [ + "rest", + "grpc", + ] + for transport in transports: + client = TextServiceClient( + credentials=ga_credentials.AnonymousCredentials(), transport=transport + ) + # Test client calls underlying transport. + with mock.patch.object(type(client.transport), "close") as close: + close.assert_not_called() + with client: + pass + close.assert_called() + + +@pytest.mark.parametrize( + "client_class,transport_class", + [ + (TextServiceClient, transports.TextServiceGrpcTransport), + (TextServiceAsyncClient, transports.TextServiceGrpcAsyncIOTransport), + ], +) +def test_api_key_credentials(client_class, transport_class): + with mock.patch.object( + google.auth._default, "get_api_key_credentials", create=True + ) as get_api_key_credentials: + mock_cred = mock.Mock() + get_api_key_credentials.return_value = mock_cred + options = client_options.ClientOptions() + options.api_key = "api_key" + with mock.patch.object(transport_class, "__init__") as patched: + patched.return_value = None + client = client_class(client_options=options) + patched.assert_called_once_with( + credentials=mock_cred, + credentials_file=None, + host=client._DEFAULT_ENDPOINT_TEMPLATE.format( + UNIVERSE_DOMAIN=client._DEFAULT_UNIVERSE + ), + scopes=None, + client_cert_source_for_mtls=None, + quota_project_id=None, + client_info=transports.base.DEFAULT_CLIENT_INFO, + always_use_jwt_access=True, + api_audience=None, + ) diff --git a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta/test_cache_service.py b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta/test_cache_service.py index 94bb3702eed2..259ecb2f4470 100644 --- a/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta/test_cache_service.py +++ b/packages/google-ai-generativelanguage/tests/unit/gapic/generativelanguage_v1beta/test_cache_service.py @@ -4172,8 +4172,16 @@ def test_create_cached_content_rest_call_success(request_type): "mime_type": "mime_type_value", "data": b"data_blob", }, - "function_call": {"name": "name_value", "args": {"fields": {}}}, - "function_response": {"name": "name_value", "response": {}}, + "function_call": { + "id": "id_value", + "name": "name_value", + "args": {"fields": {}}, + }, + "function_response": { + "id": "id_value", + "name": "name_value", + "response": {}, + }, "file_data": { "mime_type": "mime_type_value", "file_uri": "file_uri_value", @@ -4203,12 +4211,14 @@ def test_create_cached_content_rest_call_success(request_type): "properties": {}, "required": ["required_value1", "required_value2"], }, + "response": {}, } ], "google_search_retrieval": { "dynamic_retrieval_config": {"mode": 1, "dynamic_threshold": 0.1809} }, "code_execution": {}, + "google_search": {}, } ], "tool_config": { @@ -4561,8 +4571,16 @@ def test_update_cached_content_rest_call_success(request_type): "mime_type": "mime_type_value", "data": b"data_blob", }, - "function_call": {"name": "name_value", "args": {"fields": {}}}, - "function_response": {"name": "name_value", "response": {}}, + "function_call": { + "id": "id_value", + "name": "name_value", + "args": {"fields": {}}, + }, + "function_response": { + "id": "id_value", + "name": "name_value", + "response": {}, + }, "file_data": { "mime_type": "mime_type_value", "file_uri": "file_uri_value", @@ -4592,12 +4610,14 @@ def test_update_cached_content_rest_call_success(request_type): "properties": {}, "required": ["required_value1", "required_value2"], }, + "response": {}, } ], "google_search_retrieval": { "dynamic_retrieval_config": {"mode": 1, "dynamic_threshold": 0.1809} }, "code_execution": {}, + "google_search": {}, } ], "tool_config": {