diff --git a/poetry.lock b/poetry.lock index 5fd21633..c5cbf7bc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. [[package]] name = "astroid" @@ -70,7 +70,7 @@ description = "Foreign Function Interface for Python calling C code." optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\"" files = [ {file = "cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14"}, {file = "cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67"}, @@ -475,7 +475,7 @@ description = "cryptography is a package which provides cryptographic recipes an optional = true python-versions = ">=3.7" groups = ["main"] -markers = "python_version < \"3.10\" and extra == \"true\"" +markers = "python_version < \"3.10\"" files = [ {file = "cryptography-43.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:bf7a1932ac4176486eab36a19ed4c0492da5d97123f1406cf15e41b05e787d2e"}, {file = "cryptography-43.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:63efa177ff54aec6e1c0aefaa1a241232dcd37413835a9b674b6e3f0ae2bfd3e"}, @@ -526,7 +526,7 @@ description = "cryptography is a package which provides cryptographic recipes an optional = true python-versions = "!=3.9.0,!=3.9.1,>=3.7" groups = ["main"] -markers = "python_version >= \"3.10\" and extra == \"true\"" +markers = "python_version >= \"3.10\"" files = [ {file = "cryptography-45.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:048e7ad9e08cf4c0ab07ff7f36cc3115924e22e2266e034450a890d9e312dd74"}, {file = "cryptography-45.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44647c5d796f5fc042bbc6d61307d04bf29bccb74d188f18051b635f20a9c75f"}, @@ -587,7 +587,7 @@ description = "Decorators for Humans" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform != \"win32\"" +markers = "sys_platform != \"win32\"" files = [ {file = "decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a"}, {file = "decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360"}, @@ -644,7 +644,7 @@ description = "Python GSSAPI Wrapper" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform != \"win32\"" +markers = "sys_platform != \"win32\"" files = [ {file = "gssapi-1.9.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:261e00ac426d840055ddb2199f4989db7e3ce70fa18b1538f53e392b4823e8f1"}, {file = "gssapi-1.9.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:14a1ae12fdf1e4c8889206195ba1843de09fe82587fa113112887cd5894587c6"}, @@ -725,7 +725,7 @@ description = "Kerberos API bindings for Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform != \"win32\"" +markers = "sys_platform != \"win32\"" files = [ {file = "krb5-0.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:cbdcd2c4514af5ca32d189bc31f30fee2ab297dcbff74a53bd82f92ad1f6e0ef"}, {file = "krb5-0.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:40ad837d563865946cffd65a588f24876da2809aa5ce4412de49442d7cf11d50"}, @@ -1333,6 +1333,38 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pybreaker" +version = "1.2.0" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.10\"" +files = [ + {file = "pybreaker-1.2.0-py3-none-any.whl", hash = "sha256:c3e7683e29ecb3d4421265aaea55504f1186a2fdc1f17b6b091d80d1e1eb5ede"}, + {file = "pybreaker-1.2.0.tar.gz", hash = "sha256:18707776316f93a30c1be0e4fec1f8aa5ed19d7e395a218eb2f050c8524fb2dc"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + +[[package]] +name = "pybreaker" +version = "1.4.1" +description = "Python implementation of the Circuit Breaker pattern" +optional = false +python-versions = ">=3.9" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "pybreaker-1.4.1-py3-none-any.whl", hash = "sha256:b4dab4a05195b7f2a64a6c1a6c4ba7a96534ef56ea7210e6bcb59f28897160e0"}, + {file = "pybreaker-1.4.1.tar.gz", hash = "sha256:8df2d245c73ba40c8242c56ffb4f12138fbadc23e296224740c2028ea9dc1178"}, +] + +[package.extras] +test = ["fakeredis", "mock", "pytest", "redis", "tornado", "types-mock", "types-redis"] + [[package]] name = "pycparser" version = "2.22" @@ -1340,7 +1372,7 @@ description = "C parser in Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and platform_python_implementation != \"PyPy\"" +markers = "platform_python_implementation != \"PyPy\"" files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, @@ -1422,7 +1454,6 @@ description = "Windows Negotiate Authentication Client and Server" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\"" files = [ {file = "pyspnego-0.11.2-py3-none-any.whl", hash = "sha256:74abc1fb51e59360eb5c5c9086e5962174f1072c7a50cf6da0bda9a4bcfdfbd4"}, {file = "pyspnego-0.11.2.tar.gz", hash = "sha256:994388d308fb06e4498365ce78d222bf4f3570b6df4ec95738431f61510c971b"}, @@ -1567,7 +1598,6 @@ description = "A Kerberos authentication handler for python-requests" optional = true python-versions = ">=3.6" groups = ["main"] -markers = "extra == \"true\"" files = [ {file = "requests_kerberos-0.15.0-py2.py3-none-any.whl", hash = "sha256:ba9b0980b8489c93bfb13854fd118834e576d6700bfea3745cb2e62278cd16a6"}, {file = "requests_kerberos-0.15.0.tar.gz", hash = "sha256:437512e424413d8113181d696e56694ffa4259eb9a5fc4e803926963864eaf4e"}, @@ -1597,7 +1627,7 @@ description = "SSPI API bindings for Python" optional = true python-versions = ">=3.8" groups = ["main"] -markers = "extra == \"true\" and sys_platform == \"win32\" and python_version < \"3.10\"" +markers = "python_version < \"3.10\" and sys_platform == \"win32\"" files = [ {file = "sspilib-0.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:34f566ba8b332c91594e21a71200de2d4ce55ca5a205541d4128ed23e3c98777"}, {file = "sspilib-0.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b11e4f030de5c5de0f29bcf41a6e87c9fd90cb3b0f64e446a6e1d1aef4d08f5"}, @@ -1644,7 +1674,7 @@ description = "SSPI API bindings for Python" optional = true python-versions = ">=3.9" groups = ["main"] -markers = "extra == \"true\" and sys_platform == \"win32\" and python_version >= \"3.10\"" +markers = "python_version >= \"3.10\" and sys_platform == \"win32\"" files = [ {file = "sspilib-0.3.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:c45860bdc4793af572d365434020ff5a1ef78c42a2fc2c7a7d8e44eacaf475b6"}, {file = "sspilib-0.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:62cc4de547503dec13b81a6af82b398e9ef53ea82c3535418d7d069c7a05d5cd"}, @@ -1797,9 +1827,8 @@ zstd = ["zstandard (>=0.18.0)"] [extras] pyarrow = ["pyarrow", "pyarrow"] -true = ["requests-kerberos"] [metadata] lock-version = "2.1" python-versions = "^3.8.0" -content-hash = "ddc7354d47a940fa40b4d34c43a1c42488b01258d09d771d58d64a0dfaf0b955" +content-hash = "44e18fe57647fd472bc311393a37cb6f3f66c3f061d9d4204bd346d53b4ade4a" diff --git a/pyproject.toml b/pyproject.toml index a1f43bc7..fa8619da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ pyarrow = [ { version = ">=18.0.0", python = ">=3.13", optional=true } ] pyjwt = "^2.0.0" +pybreaker = "^1.0.0" requests-kerberos = {version = "^0.15.0", optional = true} diff --git a/src/databricks/sql/auth/common.py b/src/databricks/sql/auth/common.py index 679e353f..8fb9bae9 100644 --- a/src/databricks/sql/auth/common.py +++ b/src/databricks/sql/auth/common.py @@ -50,6 +50,7 @@ def __init__( pool_connections: Optional[int] = None, pool_maxsize: Optional[int] = None, user_agent: Optional[str] = None, + telemetry_circuit_breaker_enabled: Optional[bool] = None, ): self.hostname = hostname self.access_token = access_token @@ -81,6 +82,11 @@ def __init__( self.pool_connections = pool_connections or 10 self.pool_maxsize = pool_maxsize or 20 self.user_agent = user_agent + self.telemetry_circuit_breaker_enabled = ( + telemetry_circuit_breaker_enabled + if telemetry_circuit_breaker_enabled is not None + else False + ) def get_effective_azure_login_app_id(hostname) -> str: diff --git a/src/databricks/sql/telemetry/circuit_breaker_manager.py b/src/databricks/sql/telemetry/circuit_breaker_manager.py new file mode 100644 index 00000000..86498e47 --- /dev/null +++ b/src/databricks/sql/telemetry/circuit_breaker_manager.py @@ -0,0 +1,138 @@ +""" +Circuit breaker implementation for telemetry requests. + +This module provides circuit breaker functionality to prevent telemetry failures +from impacting the main SQL operations. It uses pybreaker library to implement +the circuit breaker pattern with configurable thresholds and timeouts. +""" + +import logging +import threading +from typing import Dict, Optional, Any +from dataclasses import dataclass + +import pybreaker +from pybreaker import CircuitBreaker, CircuitBreakerError, CircuitBreakerListener + +logger = logging.getLogger(__name__) + +# Circuit Breaker Configuration Constants +MINIMUM_CALLS = 20 +RESET_TIMEOUT = 30 +CIRCUIT_BREAKER_NAME = "telemetry-circuit-breaker" + +# Circuit Breaker State Constants +CIRCUIT_BREAKER_STATE_OPEN = "open" +CIRCUIT_BREAKER_STATE_CLOSED = "closed" +CIRCUIT_BREAKER_STATE_HALF_OPEN = "half-open" +CIRCUIT_BREAKER_STATE_DISABLED = "disabled" + +# Logging Message Constants +LOG_CIRCUIT_BREAKER_STATE_CHANGED = "Circuit breaker state changed from %s to %s for %s" +LOG_CIRCUIT_BREAKER_OPENED = ( + "Circuit breaker opened for %s - telemetry requests will be blocked" +) +LOG_CIRCUIT_BREAKER_CLOSED = ( + "Circuit breaker closed for %s - telemetry requests will be allowed" +) +LOG_CIRCUIT_BREAKER_HALF_OPEN = ( + "Circuit breaker half-open for %s - testing telemetry requests" +) + + +class CircuitBreakerStateListener(CircuitBreakerListener): + """Listener for circuit breaker state changes.""" + + def before_call(self, cb: CircuitBreaker, func, *args, **kwargs) -> None: + """Called before the circuit breaker calls a function.""" + pass + + def failure(self, cb: CircuitBreaker, exc: BaseException) -> None: + """Called when a function called by the circuit breaker fails.""" + pass + + def success(self, cb: CircuitBreaker) -> None: + """Called when a function called by the circuit breaker succeeds.""" + pass + + def state_change(self, cb: CircuitBreaker, old_state, new_state) -> None: + """Called when the circuit breaker state changes.""" + old_state_name = old_state.name if old_state else "None" + new_state_name = new_state.name if new_state else "None" + + logger.info( + LOG_CIRCUIT_BREAKER_STATE_CHANGED, old_state_name, new_state_name, cb.name + ) + + if new_state_name == CIRCUIT_BREAKER_STATE_OPEN: + logger.warning(LOG_CIRCUIT_BREAKER_OPENED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_CLOSED: + logger.info(LOG_CIRCUIT_BREAKER_CLOSED, cb.name) + elif new_state_name == CIRCUIT_BREAKER_STATE_HALF_OPEN: + logger.info(LOG_CIRCUIT_BREAKER_HALF_OPEN, cb.name) + + +class CircuitBreakerManager: + """ + Manages circuit breaker instances for telemetry requests. + + This class provides a singleton pattern to manage circuit breaker instances + per host, ensuring that telemetry failures don't impact main SQL operations. + + Circuit breaker configuration is fixed and cannot be overridden. + """ + + _instances: Dict[str, CircuitBreaker] = {} + _lock = threading.RLock() + + @classmethod + def get_circuit_breaker(cls, host: str) -> CircuitBreaker: + """ + Get or create a circuit breaker instance for the specified host. + + Args: + host: The hostname for which to get the circuit breaker + + Returns: + CircuitBreaker instance for the host + """ + with cls._lock: + if host not in cls._instances: + cls._instances[host] = cls._create_circuit_breaker(host) + logger.debug("Created circuit breaker for host: %s", host) + + return cls._instances[host] + + @classmethod + def _create_circuit_breaker(cls, host: str) -> CircuitBreaker: + """ + Create a new circuit breaker instance for the specified host. + + Args: + host: The hostname for the circuit breaker + + Returns: + New CircuitBreaker instance + """ + # Create circuit breaker with fixed configuration + breaker = CircuitBreaker( + fail_max=MINIMUM_CALLS, + reset_timeout=RESET_TIMEOUT, + name=f"{CIRCUIT_BREAKER_NAME}-{host}", + ) + breaker.add_listener(CircuitBreakerStateListener()) + + return breaker + + +def is_circuit_breaker_error(exception: Exception) -> bool: + """ + Check if an exception is a circuit breaker error. + + Args: + exception: The exception to check + + Returns: + True if the exception is a circuit breaker error + """ + return isinstance(exception, CircuitBreakerError) diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 71fcc40c..f42bf7b8 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -41,6 +41,14 @@ from databricks.sql.common.feature_flag import FeatureFlagsContextFactory from databricks.sql.common.unified_http_client import UnifiedHttpClient from databricks.sql.common.http import HttpMethod +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.telemetry.circuit_breaker_manager import ( + is_circuit_breaker_error, +) if TYPE_CHECKING: from databricks.sql.client import Connection @@ -189,6 +197,21 @@ def __init__( # Create own HTTP client from client context self._http_client = UnifiedHttpClient(client_context) + # Create telemetry push client based on circuit breaker enabled flag + if client_context.telemetry_circuit_breaker_enabled: + # Create circuit breaker telemetry push client with fixed configuration + self._telemetry_push_client: ITelemetryPushClient = ( + CircuitBreakerTelemetryPushClient( + TelemetryPushClient(self._http_client), + host_url, + ) + ) + else: + # Circuit breaker disabled - use direct telemetry push client + self._telemetry_push_client: ITelemetryPushClient = TelemetryPushClient( + self._http_client + ) + def _export_event(self, event): """Add an event to the batch queue and flush if batch is full""" logger.debug("Exporting event for connection %s", self._session_id_hex) @@ -252,14 +275,21 @@ def _send_telemetry(self, events): logger.debug("Failed to submit telemetry request: %s", e) def _send_with_unified_client(self, url, data, headers, timeout=900): - """Helper method to send telemetry using the unified HTTP client.""" + """Helper method to send telemetry using the telemetry push client.""" try: - response = self._http_client.request( + response = self._telemetry_push_client.request( HttpMethod.POST, url, body=data, headers=headers, timeout=timeout ) return response except Exception as e: - logger.error("Failed to send telemetry with unified client: %s", e) + if is_circuit_breaker_error(e): + logger.warning( + "Telemetry request blocked by circuit breaker for connection %s: %s", + self._session_id_hex, + e, + ) + else: + logger.error("Failed to send telemetry: %s", e) raise def _telemetry_request_callback(self, future, sent_count: int): diff --git a/src/databricks/sql/telemetry/telemetry_push_client.py b/src/databricks/sql/telemetry/telemetry_push_client.py new file mode 100644 index 00000000..532084c8 --- /dev/null +++ b/src/databricks/sql/telemetry/telemetry_push_client.py @@ -0,0 +1,173 @@ +""" +Telemetry push client interface and implementations. + +This module provides an interface for telemetry push clients with two implementations: +1. TelemetryPushClient - Direct HTTP client implementation +2. CircuitBreakerTelemetryPushClient - Circuit breaker wrapper implementation +""" + +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Optional +from contextlib import contextmanager + +try: + from urllib3 import BaseHTTPResponse +except ImportError: + from urllib3 import HTTPResponse as BaseHTTPResponse +from pybreaker import CircuitBreakerError + +from databricks.sql.common.unified_http_client import UnifiedHttpClient +from databricks.sql.common.http import HttpMethod +from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + is_circuit_breaker_error, +) + +logger = logging.getLogger(__name__) + + +class ITelemetryPushClient(ABC): + """Interface for telemetry push clients.""" + + @abstractmethod + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> BaseHTTPResponse: + """Make an HTTP request.""" + pass + + @abstractmethod + @contextmanager + def request_context( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ): + """Context manager for making HTTP requests.""" + pass + + +class TelemetryPushClient(ITelemetryPushClient): + """Direct HTTP client implementation for telemetry requests.""" + + def __init__(self, http_client: UnifiedHttpClient): + """ + Initialize the telemetry push client. + + Args: + http_client: The underlying HTTP client + """ + self._http_client = http_client + logger.debug("TelemetryPushClient initialized") + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> BaseHTTPResponse: + """Make an HTTP request using the underlying HTTP client.""" + return self._http_client.request(method, url, headers, **kwargs) + + @contextmanager + def request_context( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ): + """Context manager for making HTTP requests.""" + with self._http_client.request_context( + method, url, headers, **kwargs + ) as response: + yield response + + +class CircuitBreakerTelemetryPushClient(ITelemetryPushClient): + """Circuit breaker wrapper implementation for telemetry requests.""" + + def __init__(self, delegate: ITelemetryPushClient, host: str): + """ + Initialize the circuit breaker telemetry push client. + + Args: + delegate: The underlying telemetry push client to wrap + host: The hostname for circuit breaker identification + """ + self._delegate = delegate + self._host = host + + # Get circuit breaker for this host (creates if doesn't exist) + self._circuit_breaker = CircuitBreakerManager.get_circuit_breaker(host) + + logger.debug( + "CircuitBreakerTelemetryPushClient initialized for host %s", + host, + ) + + def request( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> BaseHTTPResponse: + """Make an HTTP request with circuit breaker protection.""" + try: + # Use circuit breaker to protect the request + return self._circuit_breaker.call( + lambda: self._delegate.request(method, url, headers, **kwargs) + ) + except CircuitBreakerError as e: + logger.warning( + "Circuit breaker is open for host %s, blocking telemetry request to %s: %s", + self._host, + url, + e, + ) + raise + except Exception as e: + # Re-raise non-circuit breaker exceptions + logger.debug("Telemetry request failed for host %s: %s", self._host, e) + raise + + @contextmanager + def request_context( + self, + method: HttpMethod, + url: str, + headers: Optional[Dict[str, str]] = None, + **kwargs + ): + """Context manager for making HTTP requests with circuit breaker protection.""" + try: + # Use circuit breaker to protect the request + def _make_request(): + with self._delegate.request_context( + method, url, headers, **kwargs + ) as response: + return response + + response = self._circuit_breaker.call(_make_request) + yield response + except CircuitBreakerError as e: + logger.warning( + "Circuit breaker is open for host %s, blocking telemetry request to %s: %s", + self._host, + url, + e, + ) + raise + except Exception as e: + # Re-raise non-circuit breaker exceptions + logger.debug("Telemetry request failed for host %s: %s", self._host, e) + raise diff --git a/tests/unit/test_circuit_breaker_http_client.py b/tests/unit/test_circuit_breaker_http_client.py new file mode 100644 index 00000000..bc1347b3 --- /dev/null +++ b/tests/unit/test_circuit_breaker_http_client.py @@ -0,0 +1,234 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + def test_direct_client_has_no_circuit_breaker(self): + """Test that direct client does not have circuit breaker functionality.""" + # Direct client should work without circuit breaker + assert isinstance(self.client, TelemetryPushClient) + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_request_context_enabled_success(self): + """Test successful request context when circuit breaker is enabled.""" + mock_response = Mock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + self.mock_delegate.request_context.return_value = mock_context + + with self.client.request_context( + HttpMethod.POST, "https://test.com", {} + ) as response: + assert response == mock_response + + self.mock_delegate.request_context.assert_called_once() + + def test_request_context_enabled_circuit_breaker_error(self): + """Test request context when circuit breaker is open.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + with self.client.request_context( + HttpMethod.POST, "https://test.com", {} + ): + pass + + def test_request_context_enabled_other_error(self): + """Test request context when other error occurs.""" + # Mock delegate to raise a different error + self.mock_delegate.request_context.side_effect = ValueError("Network error") + + with pytest.raises(ValueError): + with self.client.request_context(HttpMethod.POST, "https://test.com", {}): + pass + + def test_request_enabled_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_circuit_breaker_error(self): + """Test request when circuit breaker is open.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_enabled_other_error(self): + """Test request when other error occurs.""" + # Mock delegate to raise a different error + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_is_circuit_breaker_enabled(self): + """Test checking if circuit breaker is enabled.""" + assert self.client._circuit_breaker is not None + + def test_circuit_breaker_state_logging(self): + """Test that circuit breaker state changes are logged.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that warning was logged + mock_logger.warning.assert_called() + warning_call = mock_logger.warning.call_args[0] + assert "Circuit breaker is open" in warning_call[0] + assert self.host in warning_call[1] + + def test_other_error_logging(self): + """Test that other errors are logged appropriately.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that debug was logged + mock_logger.debug.assert_called() + debug_call = mock_logger.debug.call_args[0] + assert "Telemetry request failed" in debug_call[0] + assert self.host in debug_call[1] + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + ) + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate failures + self.mock_delegate.request.side_effect = Exception("Network error") + + # Trigger failures up to the threshold + for i in range(MINIMUM_CALLS): + with pytest.raises(Exception): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + import time + + # Clear any existing state + CircuitBreakerManager._instances.clear() + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate failures first + self.mock_delegate.request.side_effect = Exception("Network error") + + # Trigger failures up to the threshold + for i in range(MINIMUM_CALLS): + with pytest.raises(Exception): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Circuit should be open now + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 0.1) + + # Simulate successful calls + self.mock_delegate.request.side_effect = None + self.mock_delegate.request.return_value = Mock() + + # Should work again + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None diff --git a/tests/unit/test_circuit_breaker_manager.py b/tests/unit/test_circuit_breaker_manager.py new file mode 100644 index 00000000..62397a0e --- /dev/null +++ b/tests/unit/test_circuit_breaker_manager.py @@ -0,0 +1,249 @@ +""" +Unit tests for circuit breaker manager functionality. +""" + +import pytest +import threading +import time +from unittest.mock import Mock, patch + +from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + is_circuit_breaker_error, + MINIMUM_CALLS, + RESET_TIMEOUT, + CIRCUIT_BREAKER_NAME, +) +from pybreaker import CircuitBreakerError + + +class TestCircuitBreakerManager: + """Test cases for CircuitBreakerManager.""" + + def setup_method(self): + """Set up test fixtures.""" + # Clear any existing instances + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_get_circuit_breaker_creates_instance(self): + """Test getting circuit breaker creates instance with correct config.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker.name == "telemetry-circuit-breaker-test-host" + assert breaker.fail_max == MINIMUM_CALLS + + def test_get_circuit_breaker_same_host(self): + """Test that same host returns same circuit breaker instance.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("test-host") + breaker2 = CircuitBreakerManager.get_circuit_breaker("test-host") + + assert breaker1 is breaker2 + + def test_get_circuit_breaker_different_hosts(self): + """Test that different hosts return different circuit breaker instances.""" + breaker1 = CircuitBreakerManager.get_circuit_breaker("host1") + breaker2 = CircuitBreakerManager.get_circuit_breaker("host2") + + assert breaker1 is not breaker2 + assert breaker1.name != breaker2.name + + def test_get_circuit_breaker_creates_breaker(self): + """Test getting circuit breaker creates and returns breaker.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + assert breaker is not None + assert breaker.current_state in ["closed", "open", "half-open"] + + def test_thread_safety(self): + """Test thread safety of circuit breaker manager.""" + results = [] + + def get_breaker(host): + breaker = CircuitBreakerManager.get_circuit_breaker(host) + results.append(breaker) + + # Create multiple threads accessing circuit breakers + threads = [] + for i in range(10): + thread = threading.Thread(target=get_breaker, args=(f"host{i % 3}",)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + # Should have 10 results + assert len(results) == 10 + + # All breakers for same host should be same instance + host0_breakers = [b for b in results if b.name.endswith("host0")] + assert all(b is host0_breakers[0] for b in host0_breakers) + + +class TestCircuitBreakerErrorDetection: + """Test cases for circuit breaker error detection.""" + + def test_is_circuit_breaker_error_true(self): + """Test detecting circuit breaker errors.""" + error = CircuitBreakerError("Circuit breaker is open") + assert is_circuit_breaker_error(error) is True + + def test_is_circuit_breaker_error_false(self): + """Test detecting non-circuit breaker errors.""" + error = ValueError("Some other error") + assert is_circuit_breaker_error(error) is False + + error = RuntimeError("Another error") + assert is_circuit_breaker_error(error) is False + + def test_is_circuit_breaker_error_none(self): + """Test with None input.""" + assert is_circuit_breaker_error(None) is False + + +class TestCircuitBreakerIntegration: + """Integration tests for circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + CircuitBreakerManager._instances.clear() + + def teardown_method(self): + """Clean up after tests.""" + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_state_transitions(self): + """Test circuit breaker state transitions.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + # Initially should be closed + assert breaker.current_state == "closed" + + # Simulate failures to trigger circuit breaker + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold (MINIMUM_CALLS = 20) + for i in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + breaker.call(failing_func) + + # Circuit breaker should be open + assert breaker.current_state == "open" + + def test_circuit_breaker_recovery(self): + """Test circuit breaker recovery after failures.""" + breaker = CircuitBreakerManager.get_circuit_breaker("test-host") + + # Trigger circuit breaker to open + def failing_func(): + raise Exception("Simulated failure") + + # Trigger failures up to the threshold + for i in range(MINIMUM_CALLS): + with pytest.raises(Exception): + breaker.call(failing_func) + + # Circuit should be open now + assert breaker.current_state == "open" + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 0.1) + + # Try successful call to close circuit breaker + def successful_func(): + return "success" + + try: + result = breaker.call(successful_func) + # If successful, circuit should transition to closed or half-open + assert result == "success" + except CircuitBreakerError: + # Circuit might still be open, which is acceptable + pass + + # Circuit breaker should be closed or half-open (not permanently open) + assert breaker.current_state in ["closed", "half-open", "open"] + + def test_circuit_breaker_state_listener_half_open(self): + """Test circuit breaker state listener logs half-open state.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerStateListener, + CIRCUIT_BREAKER_STATE_HALF_OPEN, + ) + from unittest.mock import patch + + listener = CircuitBreakerStateListener() + + # Mock circuit breaker with half-open state + mock_cb = Mock() + mock_cb.name = "test-breaker" + + # Mock old and new states + mock_old_state = Mock() + mock_old_state.name = "open" + + mock_new_state = Mock() + mock_new_state.name = CIRCUIT_BREAKER_STATE_HALF_OPEN + + with patch( + "databricks.sql.telemetry.circuit_breaker_manager.logger" + ) as mock_logger: + listener.state_change(mock_cb, mock_old_state, mock_new_state) + + # Check that half-open state was logged + mock_logger.info.assert_called() + calls = mock_logger.info.call_args_list + half_open_logged = any("half-open" in str(call) for call in calls) + assert half_open_logged + + def test_circuit_breaker_state_listener_all_states(self): + """Test circuit breaker state listener logs all possible state transitions.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerStateListener, + CIRCUIT_BREAKER_STATE_HALF_OPEN, + CIRCUIT_BREAKER_STATE_OPEN, + CIRCUIT_BREAKER_STATE_CLOSED, + ) + from unittest.mock import patch + + listener = CircuitBreakerStateListener() + mock_cb = Mock() + mock_cb.name = "test-breaker" + + # Test all state transitions with exact constants + state_transitions = [ + (CIRCUIT_BREAKER_STATE_CLOSED, CIRCUIT_BREAKER_STATE_OPEN), + (CIRCUIT_BREAKER_STATE_OPEN, CIRCUIT_BREAKER_STATE_HALF_OPEN), + (CIRCUIT_BREAKER_STATE_HALF_OPEN, CIRCUIT_BREAKER_STATE_CLOSED), + (CIRCUIT_BREAKER_STATE_CLOSED, CIRCUIT_BREAKER_STATE_HALF_OPEN), + ] + + with patch( + "databricks.sql.telemetry.circuit_breaker_manager.logger" + ) as mock_logger: + for old_state_name, new_state_name in state_transitions: + mock_old_state = Mock() + mock_old_state.name = old_state_name + + mock_new_state = Mock() + mock_new_state.name = new_state_name + + listener.state_change(mock_cb, mock_old_state, mock_new_state) + + # Verify that logging was called for each transition + assert mock_logger.info.call_count >= len(state_transitions) + + def test_get_circuit_breaker_creates_on_demand(self): + """Test that circuit breaker is created on first access.""" + # Test with a host that doesn't exist yet + breaker = CircuitBreakerManager.get_circuit_breaker("new-host") + assert breaker is not None + assert "new-host" in CircuitBreakerManager._instances diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 2ff82cee..3438bcf8 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -27,7 +27,9 @@ def mock_telemetry_client(): client_context = MagicMock() # Patch the _setup_pool_manager method to avoid SSL file loading - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): return TelemetryClient( telemetry_enabled=True, session_id_hex=session_id, @@ -85,7 +87,7 @@ def test_network_request_flow(self, mock_http_request, mock_telemetry_client): mock_response.status = 200 mock_response.status_code = 200 mock_http_request.return_value = mock_response - + client = mock_telemetry_client # Create mock events @@ -221,7 +223,9 @@ def test_client_lifecycle_flow(self): client_context = MagicMock() # Initialize enabled client - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, session_id_hex=session_id_hex, @@ -289,7 +293,9 @@ def test_factory_shutdown_flow(self): client_context = MagicMock() # Initialize multiple clients - with patch('databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers'): + with patch( + "databricks.sql.common.unified_http_client.UnifiedHttpClient._setup_pool_managers" + ): for session in [session1, session2]: TelemetryClientFactory.initialize_telemetry_client( telemetry_enabled=True, @@ -372,8 +378,10 @@ def test_telemetry_enabled_when_flag_is_true(self, mock_http_request, MockSessio mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-true" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -400,8 +408,10 @@ def test_telemetry_disabled_when_flag_is_false( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-false" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request @@ -428,8 +438,10 @@ def test_telemetry_disabled_when_flag_request_fails( mock_session_instance = MockSession.return_value mock_session_instance.guid_hex = "test-session-ff-fail" mock_session_instance.auth_provider = AccessTokenAuthProvider("token") - mock_session_instance.is_open = False # Connection starts closed for test cleanup - + mock_session_instance.is_open = ( + False # Connection starts closed for test cleanup + ) + # Set up mock HTTP client on the session mock_http_client = MagicMock() mock_http_client.request = mock_http_request diff --git a/tests/unit/test_telemetry_circuit_breaker_integration.py b/tests/unit/test_telemetry_circuit_breaker_integration.py new file mode 100644 index 00000000..d3d19c98 --- /dev/null +++ b/tests/unit/test_telemetry_circuit_breaker_integration.py @@ -0,0 +1,364 @@ +""" +Integration tests for telemetry circuit breaker functionality. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +import threading +import time + +from databricks.sql.telemetry.telemetry_client import TelemetryClient +from databricks.sql.auth.common import ClientContext +from databricks.sql.auth.authenticators import AccessTokenAuthProvider +from pybreaker import CircuitBreakerError + + +class TestTelemetryCircuitBreakerIntegration: + """Integration tests for telemetry circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + # Create mock client context with circuit breaker config + self.client_context = Mock(spec=ClientContext) + self.client_context.telemetry_circuit_breaker_enabled = True + self.client_context.telemetry_circuit_breaker_failure_threshold = ( + 0.1 # 10% failure rate + ) + self.client_context.telemetry_circuit_breaker_minimum_calls = 2 + self.client_context.telemetry_circuit_breaker_timeout = 30 + self.client_context.telemetry_circuit_breaker_reset_timeout = ( + 1 # 1 second for testing + ) + + # Add required attributes for UnifiedHttpClient + self.client_context.ssl_options = None + self.client_context.socket_timeout = None + self.client_context.retry_stop_after_attempts_count = 5 + self.client_context.retry_delay_min = 1.0 + self.client_context.retry_delay_max = 10.0 + self.client_context.retry_stop_after_attempts_duration = 300.0 + self.client_context.retry_delay_default = 5.0 + self.client_context.retry_dangerous_codes = [] + self.client_context.proxy_auth_method = None + self.client_context.pool_connections = 10 + self.client_context.pool_maxsize = 20 + self.client_context.user_agent = None + self.client_context.hostname = "test-host.example.com" + + # Create mock auth provider + self.auth_provider = Mock(spec=AccessTokenAuthProvider) + + # Create mock executor + self.executor = Mock() + + # Create telemetry client + self.telemetry_client = TelemetryClient( + telemetry_enabled=True, + session_id_hex="test-session", + auth_provider=self.auth_provider, + host_url="test-host.example.com", + executor=self.executor, + batch_size=10, + client_context=self.client_context, + ) + + def teardown_method(self): + """Clean up after tests.""" + # Clear circuit breaker instances + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + ) + + CircuitBreakerManager._instances.clear() + + def test_telemetry_client_initialization(self): + """Test that telemetry client initializes with circuit breaker.""" + assert self.telemetry_client._telemetry_push_client is not None + # Verify circuit breaker is enabled by checking the push client type + from databricks.sql.telemetry.telemetry_push_client import ( + CircuitBreakerTelemetryPushClient, + ) + + assert isinstance( + self.telemetry_client._telemetry_push_client, + CircuitBreakerTelemetryPushClient, + ) + + def test_telemetry_client_circuit_breaker_disabled(self): + """Test telemetry client with circuit breaker disabled.""" + self.client_context.telemetry_circuit_breaker_enabled = False + + telemetry_client = TelemetryClient( + telemetry_enabled=True, + session_id_hex="test-session-2", + auth_provider=self.auth_provider, + host_url="test-host.example.com", + executor=self.executor, + batch_size=10, + client_context=self.client_context, + ) + + # Verify circuit breaker is NOT enabled by checking the push client type + from databricks.sql.telemetry.telemetry_push_client import ( + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, + ) + + assert isinstance(telemetry_client._telemetry_push_client, TelemetryPushClient) + assert not isinstance( + telemetry_client._telemetry_push_client, CircuitBreakerTelemetryPushClient + ) + + def test_telemetry_request_with_circuit_breaker_success(self): + """Test successful telemetry request with circuit breaker.""" + # Mock successful response + mock_response = Mock() + mock_response.status = 200 + mock_response.data = b'{"numProtoSuccess": 1, "errors": []}' + + with patch.object( + self.telemetry_client._telemetry_push_client, + "request", + return_value=mock_response, + ): + # Mock the callback to avoid actual processing + with patch.object(self.telemetry_client, "_telemetry_request_callback"): + self.telemetry_client._send_with_unified_client( + "https://test.com/telemetry", + '{"test": "data"}', + {"Content-Type": "application/json"}, + ) + + def test_telemetry_request_with_circuit_breaker_error(self): + """Test telemetry request when circuit breaker is open.""" + # Mock circuit breaker error + with patch.object( + self.telemetry_client._telemetry_push_client, + "request", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.telemetry_client._send_with_unified_client( + "https://test.com/telemetry", + '{"test": "data"}', + {"Content-Type": "application/json"}, + ) + + def test_telemetry_request_with_other_error(self): + """Test telemetry request with other network error.""" + # Mock network error + with patch.object( + self.telemetry_client._telemetry_push_client, + "request", + side_effect=ValueError("Network error"), + ): + with pytest.raises(ValueError): + self.telemetry_client._send_with_unified_client( + "https://test.com/telemetry", + '{"test": "data"}', + {"Content-Type": "application/json"}, + ) + + def test_circuit_breaker_opens_after_telemetry_failures(self): + """Test that circuit breaker opens after repeated telemetry failures.""" + # Mock failures + with patch.object( + self.telemetry_client._telemetry_push_client, + "request", + side_effect=Exception("Network error"), + ): + # Simulate multiple failures + for _ in range(3): + try: + self.telemetry_client._send_with_unified_client( + "https://test.com/telemetry", + '{"test": "data"}', + {"Content-Type": "application/json"}, + ) + except Exception: + pass + + # Circuit breaker should eventually open + # Note: This test might be flaky due to timing, but it tests the integration + time.sleep(0.1) # Give circuit breaker time to process + + def test_telemetry_client_factory_integration(self): + """Test telemetry client factory with circuit breaker.""" + from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory + + # Clear any existing clients + TelemetryClientFactory._clients.clear() + + # Initialize telemetry client through factory + TelemetryClientFactory.initialize_telemetry_client( + telemetry_enabled=True, + session_id_hex="factory-test-session", + auth_provider=self.auth_provider, + host_url="test-host.example.com", + batch_size=10, + client_context=self.client_context, + ) + + # Get the client + client = TelemetryClientFactory.get_telemetry_client("factory-test-session") + + # Should have circuit breaker enabled + from databricks.sql.telemetry.telemetry_push_client import ( + CircuitBreakerTelemetryPushClient, + ) + + assert isinstance( + client._telemetry_push_client, CircuitBreakerTelemetryPushClient + ) + + # Clean up + TelemetryClientFactory.close("factory-test-session") + + def test_circuit_breaker_configuration_from_client_context(self): + """Test that circuit breaker configuration is properly read from client context.""" + # Test with custom configuration + self.client_context.telemetry_circuit_breaker_minimum_calls = 5 + self.client_context.telemetry_circuit_breaker_reset_timeout = 120 + + telemetry_client = TelemetryClient( + telemetry_enabled=True, + session_id_hex="config-test-session", + auth_provider=self.auth_provider, + host_url="test-host.example.com", + executor=self.executor, + batch_size=10, + client_context=self.client_context, + ) + + # Verify circuit breaker is enabled with custom config + from databricks.sql.telemetry.telemetry_push_client import ( + CircuitBreakerTelemetryPushClient, + ) + + assert isinstance( + telemetry_client._telemetry_push_client, CircuitBreakerTelemetryPushClient + ) + # The config is used internally but not exposed as an attribute anymore + + def test_circuit_breaker_logging(self): + """Test that circuit breaker events are properly logged.""" + with patch("databricks.sql.telemetry.telemetry_client.logger") as mock_logger: + # Mock circuit breaker error + with patch.object( + self.telemetry_client._telemetry_push_client, + "request", + side_effect=CircuitBreakerError("Circuit is open"), + ): + try: + self.telemetry_client._send_with_unified_client( + "https://test.com/telemetry", + '{"test": "data"}', + {"Content-Type": "application/json"}, + ) + except CircuitBreakerError: + pass + + # Check that warning was logged + mock_logger.warning.assert_called() + warning_call = mock_logger.warning.call_args[0] + assert "Telemetry request blocked by circuit breaker" in warning_call[0] + assert ( + "test-session" in warning_call[1] + ) # session_id_hex is the second argument + + +class TestTelemetryCircuitBreakerThreadSafety: + """Test thread safety of telemetry circuit breaker functionality.""" + + def setup_method(self): + """Set up test fixtures.""" + self.client_context = Mock(spec=ClientContext) + self.client_context.telemetry_circuit_breaker_enabled = True + self.client_context.telemetry_circuit_breaker_failure_threshold = 0.1 + self.client_context.telemetry_circuit_breaker_minimum_calls = 2 + self.client_context.telemetry_circuit_breaker_timeout = 30 + self.client_context.telemetry_circuit_breaker_reset_timeout = 1 + + # Add required attributes for UnifiedHttpClient + self.client_context.ssl_options = None + self.client_context.socket_timeout = None + self.client_context.retry_stop_after_attempts_count = 5 + self.client_context.retry_delay_min = 1.0 + self.client_context.retry_delay_max = 10.0 + self.client_context.retry_stop_after_attempts_duration = 300.0 + self.client_context.retry_delay_default = 5.0 + self.client_context.retry_dangerous_codes = [] + self.client_context.proxy_auth_method = None + self.client_context.pool_connections = 10 + self.client_context.pool_maxsize = 20 + self.client_context.user_agent = None + self.client_context.hostname = "test-host.example.com" + + self.auth_provider = Mock(spec=AccessTokenAuthProvider) + self.executor = Mock() + + def teardown_method(self): + """Clean up after tests.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + ) + + CircuitBreakerManager._instances.clear() + + def test_concurrent_telemetry_requests(self): + """Test concurrent telemetry requests with circuit breaker.""" + # Clear any existing circuit breaker state + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + ) + + CircuitBreakerManager._instances.clear() + + telemetry_client = TelemetryClient( + telemetry_enabled=True, + session_id_hex="concurrent-test-session", + auth_provider=self.auth_provider, + host_url="test-host.example.com", + executor=self.executor, + batch_size=10, + client_context=self.client_context, + ) + + results = [] + errors = [] + + def make_request(): + try: + # Mock the underlying HTTP client to fail, not the telemetry push client + with patch.object( + telemetry_client._http_client, + "request", + side_effect=Exception("Network error"), + ): + telemetry_client._send_with_unified_client( + "https://test.com/telemetry", + '{"test": "data"}', + {"Content-Type": "application/json"}, + ) + results.append("success") + except Exception as e: + errors.append(type(e).__name__) + + # Create multiple threads (enough to trigger circuit breaker) + from databricks.sql.telemetry.circuit_breaker_manager import MINIMUM_CALLS + + num_threads = MINIMUM_CALLS + 5 # Enough to open the circuit + threads = [] + for _ in range(num_threads): + thread = threading.Thread(target=make_request) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Should have some results and some errors + assert len(results) + len(errors) == num_threads + # Some should be CircuitBreakerError after circuit opens + assert "CircuitBreakerError" in errors or len(errors) == 0 diff --git a/tests/unit/test_telemetry_push_client.py b/tests/unit/test_telemetry_push_client.py new file mode 100644 index 00000000..a9e0baec --- /dev/null +++ b/tests/unit/test_telemetry_push_client.py @@ -0,0 +1,302 @@ +""" +Unit tests for telemetry push client functionality. +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +import urllib.parse + +from databricks.sql.telemetry.telemetry_push_client import ( + ITelemetryPushClient, + TelemetryPushClient, + CircuitBreakerTelemetryPushClient, +) +from databricks.sql.common.http import HttpMethod +from pybreaker import CircuitBreakerError + + +class TestTelemetryPushClient: + """Test cases for TelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_http_client = Mock() + self.client = TelemetryPushClient(self.mock_http_client) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._http_client == self.mock_http_client + + def test_request_delegates_to_http_client(self): + """Test that request delegates to underlying HTTP client.""" + mock_response = Mock() + self.mock_http_client.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_http_client.request.assert_called_once() + + def test_direct_client_has_no_circuit_breaker(self): + """Test that direct client does not have circuit breaker functionality.""" + # Direct client should work without circuit breaker + assert isinstance(self.client, TelemetryPushClient) + + +class TestCircuitBreakerTelemetryPushClient: + """Test cases for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock(spec=ITelemetryPushClient) + self.host = "test-host.example.com" + self.client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + def test_initialization(self): + """Test client initialization.""" + assert self.client._delegate == self.mock_delegate + assert self.client._host == self.host + assert self.client._circuit_breaker is not None + + def test_initialization_disabled(self): + """Test client initialization with circuit breaker disabled.""" + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + assert client._circuit_breaker is not None + + def test_request_context_disabled(self): + """Test request context when circuit breaker is disabled.""" + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + mock_response = Mock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + self.mock_delegate.request_context.return_value = mock_context + + with client.request_context( + HttpMethod.POST, "https://test.com", {} + ) as response: + assert response == mock_response + + self.mock_delegate.request_context.assert_called_once() + + def test_request_context_enabled_success(self): + """Test successful request context when circuit breaker is enabled.""" + mock_response = Mock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + self.mock_delegate.request_context.return_value = mock_context + + with self.client.request_context( + HttpMethod.POST, "https://test.com", {} + ) as response: + assert response == mock_response + + self.mock_delegate.request_context.assert_called_once() + + def test_request_context_enabled_circuit_breaker_error(self): + """Test request context when circuit breaker is open.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + with self.client.request_context( + HttpMethod.POST, "https://test.com", {} + ): + pass + + def test_request_context_enabled_other_error(self): + """Test request context when other error occurs.""" + # Mock delegate to raise a different error + self.mock_delegate.request_context.side_effect = ValueError("Network error") + + with pytest.raises(ValueError): + with self.client.request_context(HttpMethod.POST, "https://test.com", {}): + pass + + def test_request_disabled(self): + """Test request method when circuit breaker is disabled.""" + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_success(self): + """Test successful request when circuit breaker is enabled.""" + mock_response = Mock() + self.mock_delegate.request.return_value = mock_response + + response = self.client.request(HttpMethod.POST, "https://test.com", {}) + + assert response == mock_response + self.mock_delegate.request.assert_called_once() + + def test_request_enabled_circuit_breaker_error(self): + """Test request when circuit breaker is open.""" + # Mock circuit breaker to raise CircuitBreakerError + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_request_enabled_other_error(self): + """Test request when other error occurs.""" + # Mock delegate to raise a different error + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + def test_is_circuit_breaker_enabled(self): + """Test checking if circuit breaker is enabled.""" + # Circuit breaker is always enabled in this implementation + assert self.client._circuit_breaker is not None + + def test_circuit_breaker_state_logging(self): + """Test that circuit breaker state changes are logged.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + with patch.object( + self.client._circuit_breaker, + "call", + side_effect=CircuitBreakerError("Circuit is open"), + ): + with pytest.raises(CircuitBreakerError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that warning was logged + mock_logger.warning.assert_called() + warning_args = mock_logger.warning.call_args[0] + assert "Circuit breaker is open" in warning_args[0] + assert self.host in warning_args[1] # The host is the second argument + + def test_other_error_logging(self): + """Test that other errors are logged appropriately.""" + with patch( + "databricks.sql.telemetry.telemetry_push_client.logger" + ) as mock_logger: + self.mock_delegate.request.side_effect = ValueError("Network error") + + with pytest.raises(ValueError): + self.client.request(HttpMethod.POST, "https://test.com", {}) + + # Check that debug was logged + mock_logger.debug.assert_called() + debug_args = mock_logger.debug.call_args[0] + assert "Telemetry request failed" in debug_args[0] + assert self.host in debug_args[1] # The host is the second argument + + +class TestCircuitBreakerTelemetryPushClientIntegration: + """Integration tests for CircuitBreakerTelemetryPushClient.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_delegate = Mock() + self.host = "test-host.example.com" + # Clear any existing circuit breaker state + from databricks.sql.telemetry.circuit_breaker_manager import ( + CircuitBreakerManager, + ) + + CircuitBreakerManager._instances.clear() + + def test_circuit_breaker_opens_after_failures(self): + """Test that circuit breaker opens after repeated failures.""" + from databricks.sql.telemetry.circuit_breaker_manager import MINIMUM_CALLS + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate failures + self.mock_delegate.request.side_effect = Exception("Network error") + + # Trigger failures up to the threshold + for i in range(MINIMUM_CALLS): + with pytest.raises(Exception): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Next call should fail with CircuitBreakerError (circuit is now open) + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + def test_circuit_breaker_recovers_after_success(self): + """Test that circuit breaker recovers after successful calls.""" + from databricks.sql.telemetry.circuit_breaker_manager import ( + MINIMUM_CALLS, + RESET_TIMEOUT, + ) + import time + + client = CircuitBreakerTelemetryPushClient(self.mock_delegate, self.host) + + # Simulate failures first + self.mock_delegate.request.side_effect = Exception("Network error") + + # Trigger failures up to the threshold + for i in range(MINIMUM_CALLS): + with pytest.raises(Exception): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Circuit should be open now + with pytest.raises(CircuitBreakerError): + client.request(HttpMethod.POST, "https://test.com", {}) + + # Wait for reset timeout + time.sleep(RESET_TIMEOUT + 0.1) + + # Simulate successful calls + self.mock_delegate.request.side_effect = None + self.mock_delegate.request.return_value = Mock() + + # Should work again + response = client.request(HttpMethod.POST, "https://test.com", {}) + assert response is not None + + def test_urllib3_import_fallback(self): + """Test that the urllib3 import fallback works correctly.""" + # This test verifies that the import fallback mechanism exists + # The actual fallback is tested by the fact that the module imports successfully + # even when BaseHTTPResponse is not available + from databricks.sql.telemetry.telemetry_push_client import BaseHTTPResponse + + assert BaseHTTPResponse is not None + + def test_telemetry_push_client_request_context(self): + """Test that TelemetryPushClient.request_context works correctly.""" + from unittest.mock import Mock, MagicMock + + # Create a mock HTTP client + mock_http_client = Mock() + mock_response = Mock() + + # Mock the context manager + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_response + mock_context.__exit__.return_value = None + mock_http_client.request_context.return_value = mock_context + + # Create TelemetryPushClient + client = TelemetryPushClient(mock_http_client) + + # Test request_context + with client.request_context("GET", "https://example.com") as response: + assert response == mock_response + + # Verify that the HTTP client's request_context was called + mock_http_client.request_context.assert_called_once_with( + "GET", "https://example.com", None + )