diff --git a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py index 1ec44e21c..a6719ef5a 100644 --- a/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py +++ b/packages/service/ni_measurement_plugin_sdk_service/grpc/channelpool.py @@ -2,16 +2,13 @@ from __future__ import annotations +import ipaddress +import re import sys from threading import Lock from types import TracebackType -from typing import ( - Dict, - Literal, - Optional, - Type, - TYPE_CHECKING, -) +from typing import TYPE_CHECKING, Dict, Literal, Optional, Type +from urllib.parse import urlparse import grpc @@ -57,9 +54,7 @@ def get_channel(self, target: str) -> grpc.Channel: with self._lock: if target not in self._channel_cache: self._lock.release() - new_channel = grpc.insecure_channel(target) - if ClientLogger.is_enabled(): - new_channel = grpc.intercept_channel(new_channel, ClientLogger()) + new_channel = self._create_channel(target) self._lock.acquire() if target not in self._channel_cache: self._channel_cache[target] = new_channel @@ -78,3 +73,44 @@ def close(self) -> None: for channel in self._channel_cache.values(): channel.close() self._channel_cache.clear() + + def _create_channel(self, target: str) -> grpc.Channel: + options = [ + ("grpc.max_receive_message_length", -1), + ("grpc.max_send_message_length", -1), + ] + if self._is_local(target): + options.append(("grpc.enable_http_proxy", 0)) + channel = grpc.insecure_channel(target, options) + if ClientLogger.is_enabled(): + channel = grpc.intercept_channel(channel, ClientLogger()) + return channel + + def _is_local(self, target: str) -> bool: + hostname = "" + # First, check if the target string is in URL format + parse_result = urlparse(target) + if parse_result.scheme and parse_result.hostname and parse_result.port: + hostname = parse_result.hostname + else: + # Next, check for target string in : format + match = re.match(r"^(.*):(\d+)$", target) + if match: + hostname = match.group(1) + + if not hostname: + return False + if hostname == "localhost" or hostname == "LOCALHOST": + return True + + # IPv6 addresses don't support parsing with leading/trailing brackets + # so we need to remove them. + match = re.match(r"^\[(.*)\]$", hostname) + if match: + hostname = match.group(1) + + try: + address = ipaddress.ip_address(hostname) + return address.is_loopback + except ValueError: + return False diff --git a/packages/service/tests/unit/grpc/__init__.py b/packages/service/tests/unit/grpc/__init__.py new file mode 100644 index 000000000..c8b601ecb --- /dev/null +++ b/packages/service/tests/unit/grpc/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ni_measurement_plugin_sdk_service.grpc.""" diff --git a/packages/service/tests/unit/grpc/channelpool/__init__.py b/packages/service/tests/unit/grpc/channelpool/__init__.py new file mode 100644 index 000000000..48a9614f8 --- /dev/null +++ b/packages/service/tests/unit/grpc/channelpool/__init__.py @@ -0,0 +1 @@ +"""Unit tests for ni_measurement_plugin_sdk_service.grpc.channelpool.""" diff --git a/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py new file mode 100644 index 000000000..8d402fc14 --- /dev/null +++ b/packages/service/tests/unit/grpc/channelpool/test_channel_pool.py @@ -0,0 +1,32 @@ +import pytest + +from ni_measurement_plugin_sdk_service.grpc.channelpool import GrpcChannelPool + + +@pytest.mark.parametrize( + "target,expected_result", + [ + ("127.0.0.1", False), # Port must be specified explicitly + ("[::1]", False), # Port must be specified explicitly + ("localhost", False), # Port must be specified explicitly + ("127.0.0.1:100", True), + ("[::1]:100", True), + ("localhost:100", True), + ("http://127.0.0.1", False), # Port must be specified explicitly + ("http://[::1]", False), # Port must be specified explicitly + ("http://localhost", False), # Port must be specified explicitly + ("http://127.0.0.1:100", True), + ("http://[::1]:100", True), + ("http://localhost:100", True), + ("1.1.1.1:100", False), + ("http://www.google.com:80", False), + ], +) +def test___channel_pool___is_local___returns_expected_result( + target: str, expected_result: bool +) -> None: + channel_pool = GrpcChannelPool() + + result = channel_pool._is_local(target) + + assert result == expected_result