From 1a6afda16f3657c41e879c2275a756e904398811 Mon Sep 17 00:00:00 2001 From: Zexin Yao Date: Thu, 27 Jun 2024 21:36:52 +0000 Subject: [PATCH] SNOW-1507212 create an interface for Snowflake restful class(es) Description - create an interface SnowflakeRestfulInterface - both the client side restful class and the server side restful class shall conform to this interface - update the type annotation to use SnowflakeRestfulInterface instead of the concrete class SnowflakeRestful Testing --- DESCRIPTION.md | 3 + src/snowflake/connector/connection.py | 3 +- src/snowflake/connector/network.py | 7 +- .../connector/snowflake_restful_interface.py | 142 ++++++++++++++++++ src/snowflake/connector/telemetry.py | 6 +- test/unit/test_session_manager.py | 5 +- 6 files changed, 157 insertions(+), 9 deletions(-) create mode 100644 src/snowflake/connector/snowflake_restful_interface.py diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 4d7fd13fc..521729ebd 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -8,6 +8,9 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne # Release Notes +- v3.11.1(TBD) + - add an interface class for SnowflakeRestful + - v3.11.0(June 17,2024) - Added support for `token_file_path` connection parameter to read an OAuth token from a file when connecting to Snowflake. - Added support for `debug_arrow_chunk` connection parameter to allow debugging raw arrow data in case of arrow data parsing failure. diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 7137036d5..20246e1c5 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -103,6 +103,7 @@ ReauthenticationRequest, SnowflakeRestful, ) +from .snowflake_restful_interface import SnowflakeRestfulInterface from .sqlstate import SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_FEATURE_NOT_SUPPORTED from .telemetry import TelemetryClient, TelemetryData, TelemetryField from .telemetry_oob import TelemetryService @@ -584,7 +585,7 @@ def client_prefetch_threads(self, value) -> None: self._validate_client_prefetch_threads() @property - def rest(self) -> SnowflakeRestful | None: + def rest(self) -> SnowflakeRestfulInterface | None: return self._rest @property diff --git a/src/snowflake/connector/network.py b/src/snowflake/connector/network.py index 26c0544eb..64d8d814e 100644 --- a/src/snowflake/connector/network.py +++ b/src/snowflake/connector/network.py @@ -86,6 +86,7 @@ ServiceUnavailableError, TooManyRequests, ) +from .snowflake_restful_interface import SnowflakeRestfulInterface from .sqlstate import ( SQLSTATE_CONNECTION_NOT_EXISTS, SQLSTATE_CONNECTION_REJECTED, @@ -319,11 +320,11 @@ def __call__(self, r: PreparedRequest) -> PreparedRequest: class SessionPool: - def __init__(self, rest: SnowflakeRestful) -> None: + def __init__(self, rest: SnowflakeRestfulInterface) -> None: # A stack of the idle sessions self._idle_sessions: list[Session] = [] self._active_sessions: set[Session] = set() - self._rest: SnowflakeRestful = rest + self._rest: SnowflakeRestfulInterface = rest def get_session(self) -> Session: """Returns a session from the session pool or creates a new one.""" @@ -361,7 +362,7 @@ def close(self) -> None: self._idle_sessions.clear() -class SnowflakeRestful: +class SnowflakeRestful(SnowflakeRestfulInterface): """Snowflake Restful class.""" def __init__( diff --git a/src/snowflake/connector/snowflake_restful_interface.py b/src/snowflake/connector/snowflake_restful_interface.py new file mode 100644 index 000000000..1f2238108 --- /dev/null +++ b/src/snowflake/connector/snowflake_restful_interface.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from .connection import SnowflakeConnection + from .vendored.requests import Session + + +class SnowflakeRestfulInterface(ABC): + """Snowflake Restful Interface + + Defines all the interfaces that we expose in the Snowflake restful classes. Both the client side restful class and + the server side one shall conform to this interface. And whenever we introduce a new public method, it should be + defined in this interface, and implemented in both restful classes. + """ + + @property + @abstractmethod + def token(self) -> str | None: + pass + + @property + @abstractmethod + def master_token(self) -> str | None: + pass + + @property + @abstractmethod + def master_validity_in_seconds(self) -> int: + pass + + @master_validity_in_seconds.setter + @abstractmethod + def master_validity_in_seconds(self, value) -> None: + pass + + @property + @abstractmethod + def id_token(self): + pass + + @id_token.setter + @abstractmethod + def id_token(self, value) -> None: + pass + + @property + @abstractmethod + def mfa_token(self) -> str | None: + pass + + @mfa_token.setter + @abstractmethod + def mfa_token(self, value: str) -> None: + pass + + @property + @abstractmethod + def server_url(self) -> str: + pass + + @abstractmethod + def close(self) -> None: + pass + + @abstractmethod + def request( + self, + url, + body=None, + method: str = "post", + client: str = "sfsql", + timeout: int | None = None, + _no_results: bool = False, + _include_retry_params: bool = False, + _no_retry: bool = False, + ): + pass + + @abstractmethod + def update_tokens( + self, + session_token, + master_token, + master_validity_in_seconds=None, + id_token=None, + mfa_token=None, + ) -> None: + """Updates session and master tokens and optionally temporary credential.""" + pass + + @abstractmethod + def delete_session(self, retry: bool = False) -> None: + """Deletes the session.""" + pass + + @abstractmethod + def fetch( + self, + method: str, + full_url: str, + headers: dict[str, Any], + data: dict[str, Any] | None = None, + timeout: int | None = None, + **kwargs, + ) -> dict[Any, Any]: + """Carry out API request with session management.""" + pass + + @staticmethod + @abstractmethod + def add_request_guid(full_url: str) -> str: + """Adds request_guid parameter for HTTP request tracing.""" + pass + + @abstractmethod + def log_and_handle_http_error_with_cause( + self, + e: Exception, + full_url: str, + method: str, + retry_timeout: int, + retry_count: int, + conn: SnowflakeConnection, + timed_out: bool = True, + ) -> None: + pass + + @abstractmethod + def handle_invalid_certificate_error(self, conn, full_url, cause) -> None: + pass + + @abstractmethod + def make_requests_session(self) -> Session: + pass diff --git a/src/snowflake/connector/telemetry.py b/src/snowflake/connector/telemetry.py index 933fc489a..45bf8444d 100644 --- a/src/snowflake/connector/telemetry.py +++ b/src/snowflake/connector/telemetry.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: from .connection import SnowflakeConnection - from .network import SnowflakeRestful + from .snowflake_restful_interface import SnowflakeRestfulInterface logger = logging.getLogger(__name__) @@ -131,8 +131,8 @@ class TelemetryClient: SF_PATH_TELEMETRY = "/telemetry/send" DEFAULT_FORCE_FLUSH_SIZE = 100 - def __init__(self, rest: SnowflakeRestful, flush_size=None) -> None: - self._rest: SnowflakeRestful | None = rest + def __init__(self, rest: SnowflakeRestfulInterface, flush_size=None) -> None: + self._rest: SnowflakeRestfulInterface | None = rest self._log_batch = [] self._flush_size = flush_size or TelemetryClient.DEFAULT_FORCE_FLUSH_SIZE self._lock = Lock() diff --git a/test/unit/test_session_manager.py b/test/unit/test_session_manager.py index 73487c588..f5e492368 100644 --- a/test/unit/test_session_manager.py +++ b/test/unit/test_session_manager.py @@ -9,6 +9,7 @@ from unittest import mock from snowflake.connector.network import SnowflakeRestful +from snowflake.connector.snowflake_restful_interface import SnowflakeRestfulInterface try: from snowflake.connector.ssl_wrap_socket import DEFAULT_OCSP_MODE @@ -32,7 +33,7 @@ class OCSPMode(Enum): mock_conn._ocsp_mode = lambda: DEFAULT_OCSP_MODE -def close_sessions(rest: SnowflakeRestful, num_session_pools: int) -> None: +def close_sessions(rest: SnowflakeRestfulInterface, num_session_pools: int) -> None: """Helper function to call SnowflakeRestful.close(). Asserts close was called on all SessionPools.""" with mock.patch("snowflake.connector.network.SessionPool.close") as close_mock: rest.close() @@ -40,7 +41,7 @@ def close_sessions(rest: SnowflakeRestful, num_session_pools: int) -> None: def create_session( - rest: SnowflakeRestful, num_sessions: int = 1, url: str | None = None + rest: SnowflakeRestfulInterface, num_sessions: int = 1, url: str | None = None ) -> None: """ Creates 'num_sessions' sessions to 'url'. This is recursive so that idle sessions