diff --git a/pubsub/google/cloud/pubsub_v1/publisher/client.py b/pubsub/google/cloud/pubsub_v1/publisher/client.py index 76ceb470da24..b837de24c6f0 100644 --- a/pubsub/google/cloud/pubsub_v1/publisher/client.py +++ b/pubsub/google/cloud/pubsub_v1/publisher/client.py @@ -27,6 +27,7 @@ from google.cloud.pubsub_v1 import _gapic from google.cloud.pubsub_v1 import types from google.cloud.pubsub_v1.gapic import publisher_client +from google.cloud.pubsub_v1.gapic.transports import publisher_grpc_transport from google.cloud.pubsub_v1.publisher._batch import thread @@ -73,16 +74,22 @@ def __init__(self, batch_settings=(), **kwargs): # Use a custom channel. # We need this in order to set appropriate default message size and # keepalive options. - if "channel" not in kwargs: - kwargs["channel"] = grpc_helpers.create_channel( - credentials=kwargs.pop("credentials", None), - target=self.target, - scopes=publisher_client.PublisherClient._DEFAULT_SCOPES, - options={ - "grpc.max_send_message_length": -1, - "grpc.max_receive_message_length": -1, - }.items(), - ) + if "transport" not in kwargs: + channel = kwargs.pop("channel", None) + if channel is None: + channel = grpc_helpers.create_channel( + credentials=kwargs.pop("credentials", None), + target=self.target, + scopes=publisher_client.PublisherClient._DEFAULT_SCOPES, + options={ + "grpc.max_send_message_length": -1, + "grpc.max_receive_message_length": -1, + }.items(), + ) + # cannot pass both 'channel' and 'credentials' + kwargs.pop("credentials", None) + transport = publisher_grpc_transport.PublisherGrpcTransport(channel=channel) + kwargs["transport"] = transport # Add the metrics headers, and instantiate the underlying GAPIC # client. diff --git a/pubsub/google/cloud/pubsub_v1/subscriber/client.py b/pubsub/google/cloud/pubsub_v1/subscriber/client.py index b50a269e99f0..0540333ad8ea 100644 --- a/pubsub/google/cloud/pubsub_v1/subscriber/client.py +++ b/pubsub/google/cloud/pubsub_v1/subscriber/client.py @@ -25,6 +25,7 @@ from google.cloud.pubsub_v1 import _gapic from google.cloud.pubsub_v1 import types from google.cloud.pubsub_v1.gapic import subscriber_client +from google.cloud.pubsub_v1.gapic.transports import subscriber_grpc_transport from google.cloud.pubsub_v1.subscriber import futures from google.cloud.pubsub_v1.subscriber._protocol import streaming_pull_manager @@ -66,17 +67,25 @@ def __init__(self, **kwargs): # Use a custom channel. # We need this in order to set appropriate default message size and # keepalive options. - if "channel" not in kwargs: - kwargs["channel"] = grpc_helpers.create_channel( - credentials=kwargs.pop("credentials", None), - target=self.target, - scopes=subscriber_client.SubscriberClient._DEFAULT_SCOPES, - options={ - "grpc.max_send_message_length": -1, - "grpc.max_receive_message_length": -1, - "grpc.keepalive_time_ms": 30000, - }.items(), + if "transport" not in kwargs: + channel = kwargs.pop("channel", None) + if channel is None: + channel = grpc_helpers.create_channel( + credentials=kwargs.pop("credentials", None), + target=self.target, + scopes=subscriber_client.SubscriberClient._DEFAULT_SCOPES, + options={ + "grpc.max_send_message_length": -1, + "grpc.max_receive_message_length": -1, + "grpc.keepalive_time_ms": 30000, + }.items(), + ) + # cannot pass both 'channel' and 'credentials' + kwargs.pop("credentials", None) + transport = subscriber_grpc_transport.SubscriberGrpcTransport( + channel=channel ) + kwargs["transport"] = transport # Add the metrics headers, and instantiate the underlying GAPIC # client. diff --git a/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py b/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py index a141e1f12187..05e4c8c67209 100644 --- a/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py +++ b/pubsub/tests/unit/pubsub_v1/publisher/test_publisher_client.py @@ -36,6 +36,19 @@ def test_init(): assert client.batch_settings.max_messages == 1000 +def test_init_w_custom_transport(): + transport = object() + client = publisher.Client(transport=transport) + + # A plain client should have an `api` (the underlying GAPIC) and a + # batch settings object, which should have the defaults. + assert isinstance(client.api, publisher_client.PublisherClient) + assert client.api.transport is transport + assert client.batch_settings.max_bytes == 10 * 1000 * 1000 + assert client.batch_settings.max_latency == 0.05 + assert client.batch_settings.max_messages == 1000 + + def test_init_emulator(monkeypatch): monkeypatch.setenv("PUBSUB_EMULATOR_HOST", "/foo/bar/") # NOTE: When the emulator host is set, a custom channel will be used, so diff --git a/pubsub/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py b/pubsub/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py index 5acd5b6f8dd7..d4914fee8f5b 100644 --- a/pubsub/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py +++ b/pubsub/tests/unit/pubsub_v1/subscriber/test_subscriber_client.py @@ -16,6 +16,7 @@ import mock from google.cloud.pubsub_v1 import subscriber +from google.cloud.pubsub_v1.gapic import subscriber_client from google.cloud.pubsub_v1 import types from google.cloud.pubsub_v1.subscriber import futures @@ -23,7 +24,14 @@ def test_init(): creds = mock.Mock(spec=credentials.Credentials) client = subscriber.Client(credentials=creds) - assert client.api is not None + assert isinstance(client.api, subscriber_client.SubscriberClient) + + +def test_init_w_custom_transport(): + transport = object() + client = subscriber.Client(transport=transport) + assert isinstance(client.api, subscriber_client.SubscriberClient) + assert client.api.transport is transport def test_init_emulator(monkeypatch):