diff --git a/google/api_core/bidi.py b/google/api_core/bidi.py index 57f5f9dd..74abc495 100644 --- a/google/api_core/bidi.py +++ b/google/api_core/bidi.py @@ -92,10 +92,7 @@ def _is_active(self): # Note: there is a possibility that this starts *before* the call # property is set. So we have to check if self.call is set before # seeing if it's active. - if self.call is not None and not self.call.is_active(): - return False - else: - return True + return self.call is not None and self.call.is_active() def __iter__(self): if self._initial_request is not None: @@ -265,6 +262,10 @@ def add_done_callback(self, callback): self._callbacks.append(callback) def _on_call_done(self, future): + # This occurs when the RPC errors or is successfully terminated. + # Note that grpc's "future" here can also be a grpc.RpcError. + # See note in https://github.com/grpc/grpc/issues/10885#issuecomment-302651331 + # that `grpc.RpcError` is also `grpc.call`. for callback in self._callbacks: callback(future) @@ -276,7 +277,13 @@ def open(self): request_generator = _RequestQueueGenerator( self._request_queue, initial_request=self._initial_request ) - call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata) + try: + call = self._start_rpc(iter(request_generator), metadata=self._rpc_metadata) + except exceptions.GoogleAPICallError as exc: + # The original `grpc.RpcError` (which is usually also a `grpc.Call`) is + # available from the ``response`` property on the mapped exception. + self._on_call_done(exc.response) + raise request_generator.call = call diff --git a/tests/unit/test_bidi.py b/tests/unit/test_bidi.py index f5e2b72b..84ac9dc5 100644 --- a/tests/unit/test_bidi.py +++ b/tests/unit/test_bidi.py @@ -804,6 +804,21 @@ def test_wake_on_error(self): while consumer.is_active: pass + def test_rpc_callback_fires_when_consumer_start_fails(self): + expected_exception = exceptions.InvalidArgument( + "test", response=grpc.StatusCode.INVALID_ARGUMENT + ) + callback = mock.Mock(spec=["__call__"]) + + rpc, _ = make_rpc() + bidi_rpc = bidi.BidiRpc(rpc) + bidi_rpc.add_done_callback(callback) + bidi_rpc._start_rpc.side_effect = expected_exception + + consumer = bidi.BackgroundConsumer(bidi_rpc, on_response=None) + consumer.start() + assert callback.call_args.args[0] == grpc.StatusCode.INVALID_ARGUMENT + def test_consumer_expected_error(self, caplog): caplog.set_level(logging.DEBUG)