diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 64b190f..8c3b07a 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -6,11 +6,12 @@ ## Upgrading - +- `channel.parse_grpc_uri()` takes an extra argument, the channel type (which can be either `grpclib.client.Channel` or `grpcio.aio.Channel`). ## New Features - Add a `exception` module to provide client exceptions, including gRPC errors with one subclass per gRPC error status code. +- `channel.parse_grpc_uri()` can now be used with `grpcio` too. ## Bug Fixes diff --git a/src/frequenz/client/base/_grpchacks.py b/src/frequenz/client/base/_grpchacks.py index 04378b6..274c862 100644 --- a/src/frequenz/client/base/_grpchacks.py +++ b/src/frequenz/client/base/_grpchacks.py @@ -3,22 +3,46 @@ """Hacks to deal with multiple grpc libraries. -This module conditionally imports the base exceptions from the `grpclib` and `grpcio` -libraries, assigning them a new name: +This module conditionally imports symbols from the `grpclib` and `grpcio` libraries, +assigning them a new name. -- [`GrpclibError`][] for [`grpclib.GRPCError`][] -- [`GrpcioError`][] for [`grpc.aio.AioRpcError`][] +for `grpclib`: -If the libraries are not installed, the module defines dummy classes with the same names +- `GrpclibError` for `grpclib.GRPCError` +- `GrpclibChannel` for `grpclib.client.Channel` + +For `grpcio`: + +- `GrpcioError` for `grpc.aio.AioRpcError` +- `GrpcioChannel` for `grpc.aio.Channel` +- `GrpcioSslChannelCredentials` for `grpc.ssl_channel_credentials` +- `grpcio_insecure_channel` for `grpc.aio.insecure_channel` +- `grpcio_secure_channel` for `grpc.aio.secure_channel` + +If the libraries are not installed, the module defines dummy symbols with the same names to avoid import errors. -This way exceptions can be caught from both libraries independently of which one is -used. The unused library will just never raise any exceptions. +This way exceptions code can be written to work with both libraries assuming both are +aviailable, and the correct symbols will be imported at runtime. """ +from typing import Any, Self + +__all__ = [ + "GrpcioChannel", + "GrpcioChannelCredentials", + "GrpcioError", + "GrpclibChannel", + "GrpclibError", + "grpcio_insecure_channel", + "grpcio_secure_channel", + "grpcio_ssl_channel_credentials", +] + try: from grpclib import GRPCError as GrpclibError + from grpclib.client import Channel as GrpclibChannel except ImportError: class GrpclibError(Exception): # type: ignore[no-redef] @@ -29,11 +53,57 @@ class GrpclibError(Exception): # type: ignore[no-redef] this class will never be instantiated. """ + class GrpclibChannel: # type: ignore[no-redef] + """A dummy class to avoid import errors. + + This class will never be actually used, as it is only used for catching + exceptions from the grpclib library. If the grpclib library is not installed, + this class will never be instantiated. + """ + + def __init__(self, target: str): + """Create an instance.""" + + async def __aenter__(self) -> Self: + """Enter a context manager.""" + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: Any | None, + ) -> bool | None: + """Exit a context manager.""" + return None + try: + from grpc import ChannelCredentials as GrpcioChannelCredentials + from grpc import ssl_channel_credentials as grpcio_ssl_channel_credentials from grpc.aio import AioRpcError as GrpcioError + from grpc.aio import Channel as GrpcioChannel + from grpc.aio import insecure_channel as grpcio_insecure_channel + from grpc.aio import secure_channel as grpcio_secure_channel except ImportError: + class GrpcioChannelCredentials: # type: ignore[no-redef] + """A dummy class to avoid import errors. + + This class will never be actually used, as it is only used for catching + exceptions from the grpc library. If the grpc library is not installed, + this class will never be instantiated. + """ + + def grpcio_ssl_channel_credentials() -> GrpcioChannelCredentials: # type: ignore[misc] + """Create a dummy function to avoid import errors. + + This function will never be actually used, as it is only used for catching + exceptions from the grpc library. If the grpc library is not installed, + this function will never be called. + """ + return GrpcioChannelCredentials() + class GrpcioError(Exception): # type: ignore[no-redef] """A dummy class to avoid import errors. @@ -42,5 +112,46 @@ class GrpcioError(Exception): # type: ignore[no-redef] this class will never be instantiated. """ + class GrpcioChannel: # type: ignore[no-redef] + """A dummy class to avoid import errors. + + This class will never be actually used, as it is only used for catching + exceptions from the grpc library. If the grpc library is not installed, + this class will never be instantiated. + """ + + async def __aenter__(self) -> Self: + """Enter a context manager.""" + return self + + async def __aexit__( + self, + _exc_type: type[BaseException] | None, + _exc_val: BaseException | None, + _exc_tb: Any | None, + ) -> bool | None: + """Exit a context manager.""" + return None + + def grpcio_insecure_channel( # type: ignore[misc] + target: str, # pylint: disable=unused-argument + ) -> GrpcioChannel: + """Create a dummy function to avoid import errors. + + This function will never be actually used, as it is only used for catching + exceptions from the grpc library. If the grpc library is not installed, + this function will never be called. + """ + return GrpcioChannel() + + def grpcio_secure_channel( # type: ignore[misc] + target: str, # pylint: disable=unused-argument + credentials: GrpcioChannelCredentials, # pylint: disable=unused-argument + ) -> GrpcioChannel: + """Create a dummy function to avoid import errors. -__all__ = ["GrpclibError", "GrpcioError"] + This function will never be actually used, as it is only used for catching + exceptions from the grpc library. If the grpc library is not installed, + this function will never be called. + """ + return GrpcioChannel() diff --git a/src/frequenz/client/base/channel.py b/src/frequenz/client/base/channel.py index dafdce0..c6f4b7b 100644 --- a/src/frequenz/client/base/channel.py +++ b/src/frequenz/client/base/channel.py @@ -3,9 +3,10 @@ """Handling of gRPC channels.""" +from typing import TypeVar from urllib.parse import parse_qs, urlparse -from grpclib.client import Channel +from . import _grpchacks def _to_bool(value: str) -> bool: @@ -17,7 +18,13 @@ def _to_bool(value: str) -> bool: raise ValueError(f"Invalid boolean value '{value}'") -def parse_grpc_uri(uri: str, /, *, default_port: int = 9090) -> Channel: +ChannelT = TypeVar("ChannelT", _grpchacks.GrpclibChannel, _grpchacks.GrpcioChannel) +"""A `grpclib` or `grpcio` channel type.""" + + +def parse_grpc_uri( + uri: str, channel_type: type[ChannelT], /, *, default_port: int = 9090 +) -> ChannelT: """Create a grpclib client channel from a URI. The URI must have the following format: @@ -38,6 +45,7 @@ def parse_grpc_uri(uri: str, /, *, default_port: int = 9090) -> Channel: Args: uri: The gRPC URI specifying the connection parameters. + channel_type: The type of channel to create. default_port: The default port number to use if the URI does not specify one. Returns: @@ -68,8 +76,19 @@ def parse_grpc_uri(uri: str, /, *, default_port: int = 9090) -> Channel: uri, ) - return Channel( - host=parsed_uri.hostname, - port=parsed_uri.port or default_port, - ssl=ssl, - ) + host = parsed_uri.hostname + port = parsed_uri.port or default_port + match channel_type: + case _grpchacks.GrpcioChannel: + target = f"{host}:{port}" + return ( + _grpchacks.grpcio_secure_channel( + target, _grpchacks.grpcio_ssl_channel_credentials() + ) + if ssl + else _grpchacks.grpcio_insecure_channel(target) + ) + case _grpchacks.GrpclibChannel: + return _grpchacks.GrpclibChannel(host=host, port=port, ssl=ssl) + case _: + assert False, "Unexpected channel type: {channel_type}" diff --git a/tests/test_channel.py b/tests/test_channel.py index 5feb366..4e30104 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -3,52 +3,52 @@ """Test cases for the channel module.""" -import unittest.mock from dataclasses import dataclass +from unittest import mock import pytest +from frequenz.client.base import _grpchacks from frequenz.client.base.channel import parse_grpc_uri +VALID_URLS = [ + ("grpc://localhost", "localhost", 9090, False), + ("grpc://localhost:1234", "localhost", 1234, False), + ("grpc://localhost:1234?ssl=true", "localhost", 1234, True), + ("grpc://localhost:1234?ssl=false", "localhost", 1234, False), + ("grpc://localhost:1234?ssl=1", "localhost", 1234, True), + ("grpc://localhost:1234?ssl=0", "localhost", 1234, False), + ("grpc://localhost:1234?ssl=on", "localhost", 1234, True), + ("grpc://localhost:1234?ssl=off", "localhost", 1234, False), + ("grpc://localhost:1234?ssl=TRUE", "localhost", 1234, True), + ("grpc://localhost:1234?ssl=FALSE", "localhost", 1234, False), + ("grpc://localhost:1234?ssl=ON", "localhost", 1234, True), + ("grpc://localhost:1234?ssl=OFF", "localhost", 1234, False), + ("grpc://localhost:1234?ssl=0&ssl=1", "localhost", 1234, True), + ("grpc://localhost:1234?ssl=1&ssl=0", "localhost", 1234, False), +] -@dataclass(frozen=True) -class _FakeChannel: - host: str - port: int - ssl: bool - -@pytest.mark.parametrize( - "uri, host, port, ssl", - [ - ("grpc://localhost", "localhost", 9090, False), - ("grpc://localhost:1234", "localhost", 1234, False), - ("grpc://localhost:1234?ssl=true", "localhost", 1234, True), - ("grpc://localhost:1234?ssl=false", "localhost", 1234, False), - ("grpc://localhost:1234?ssl=1", "localhost", 1234, True), - ("grpc://localhost:1234?ssl=0", "localhost", 1234, False), - ("grpc://localhost:1234?ssl=on", "localhost", 1234, True), - ("grpc://localhost:1234?ssl=off", "localhost", 1234, False), - ("grpc://localhost:1234?ssl=TRUE", "localhost", 1234, True), - ("grpc://localhost:1234?ssl=FALSE", "localhost", 1234, False), - ("grpc://localhost:1234?ssl=ON", "localhost", 1234, True), - ("grpc://localhost:1234?ssl=OFF", "localhost", 1234, False), - ("grpc://localhost:1234?ssl=0&ssl=1", "localhost", 1234, True), - ("grpc://localhost:1234?ssl=1&ssl=0", "localhost", 1234, False), - ], -) -def test_parse_uri_ok( +@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS) +def test_grpclib_parse_uri_ok( uri: str, host: str, port: int, ssl: bool, ) -> None: - """Test successful parsing of gRPC URIs.""" - with unittest.mock.patch( - "frequenz.client.base.channel.Channel", + """Test successful parsing of gRPC URIs using grpclib.""" + + @dataclass(frozen=True) + class _FakeChannel: + host: str + port: int + ssl: bool + + with mock.patch( + "frequenz.client.base.channel._grpchacks.GrpclibChannel", return_value=_FakeChannel(host, port, ssl), ): - channel = parse_grpc_uri(uri) + channel = parse_grpc_uri(uri, _grpchacks.GrpclibChannel) assert isinstance(channel, _FakeChannel) assert channel.host == host @@ -56,30 +56,79 @@ def test_parse_uri_ok( assert channel.ssl == ssl +@pytest.mark.parametrize("uri, host, port, ssl", VALID_URLS) +def test_grpcio_parse_uri_ok( + uri: str, + host: str, + port: int, + ssl: bool, +) -> None: + """Test successful parsing of gRPC URIs using grpcio.""" + expected_channel = mock.MagicMock( + name="mock_channel", spec=_grpchacks.GrpcioChannel + ) + expected_credentials = mock.MagicMock( + name="mock_credentials", spec=_grpchacks.GrpcioChannel + ) + + with ( + mock.patch( + "frequenz.client.base.channel._grpchacks.grpcio_insecure_channel", + return_value=expected_channel, + ) as insecure_channel_mock, + mock.patch( + "frequenz.client.base.channel._grpchacks.grpcio_secure_channel", + return_value=expected_channel, + ) as secure_channel_mock, + mock.patch( + "frequenz.client.base.channel._grpchacks.grpcio_ssl_channel_credentials", + return_value=expected_credentials, + ) as ssl_channel_credentials_mock, + ): + channel = parse_grpc_uri(uri, _grpchacks.GrpcioChannel) + + assert channel == expected_channel + expected_target = f"{host}:{port}" + if ssl: + ssl_channel_credentials_mock.assert_called_once_with() + secure_channel_mock.assert_called_once_with( + expected_target, expected_credentials + ) + else: + insecure_channel_mock.assert_called_once_with(expected_target) + + +INVALID_URLS = [ + ("http://localhost", "Invalid scheme 'http' in the URI, expected 'grpc'"), + ("grpc://", "Host name is missing in URI 'grpc://'"), + ("grpc://localhost:1234?ssl=invalid", "Invalid boolean value 'invalid'"), + ("grpc://localhost:1234?ssl=1&ssl=invalid", "Invalid boolean value 'invalid'"), + ("grpc://:1234", "Host name is missing"), + ("grpc://host:1234;param", "Port could not be cast to integer value"), + ("grpc://host:1234/path", "Unexpected path '/path'"), + ("grpc://host:1234#frag", "Unexpected fragment 'frag'"), + ("grpc://user@host:1234", "Unexpected username 'user'"), + ("grpc://:pass@host:1234?user:pass", "Unexpected password 'pass'"), + ( + "grpc://localhost?ssl=1&ssl=1&ssl=invalid", + "Invalid boolean value 'invalid'", + ), + ( + "grpc://localhost:1234?ssl=1&ffl=true", + "Unexpected query parameters {'ffl': 'true'}", + ), +] + + +@pytest.mark.parametrize("uri, error_msg", INVALID_URLS) @pytest.mark.parametrize( - "uri, error_msg", - [ - ("http://localhost", "Invalid scheme 'http' in the URI, expected 'grpc'"), - ("grpc://", "Host name is missing in URI 'grpc://'"), - ("grpc://localhost:1234?ssl=invalid", "Invalid boolean value 'invalid'"), - ("grpc://localhost:1234?ssl=1&ssl=invalid", "Invalid boolean value 'invalid'"), - ("grpc://:1234", "Host name is missing"), - ("grpc://host:1234;param", "Port could not be cast to integer value"), - ("grpc://host:1234/path", "Unexpected path '/path'"), - ("grpc://host:1234#frag", "Unexpected fragment 'frag'"), - ("grpc://user@host:1234", "Unexpected username 'user'"), - ("grpc://:pass@host:1234?user:pass", "Unexpected password 'pass'"), - ( - "grpc://localhost?ssl=1&ssl=1&ssl=invalid", - "Invalid boolean value 'invalid'", - ), - ( - "grpc://localhost:1234?ssl=1&ffl=true", - "Unexpected query parameters {'ffl': 'true'}", - ), - ], + "channel_type", [_grpchacks.GrpclibChannel, _grpchacks.GrpcioChannel], ids=str ) -def test_parse_uri_error(uri: str, error_msg: str) -> None: - """Test parsing of invalid gRPC URIs.""" +def test_grpclib_parse_uri_error( + uri: str, + error_msg: str, + channel_type: type, +) -> None: + """Test parsing of invalid gRPC URIs for grpclib.""" with pytest.raises(ValueError, match=error_msg): - parse_grpc_uri(uri) + parse_grpc_uri(uri, channel_type)