diff --git a/CHANGELOG.md b/CHANGELOG.md index 6026add..3b54073 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Do not start connection upon instantiating `SnowflakeConnector` until its methods are called - [#58](https://github.com/PrefectHQ/prefect-snowflake/pull/58) + ### Deprecated ### Removed diff --git a/prefect_snowflake/database.py b/prefect_snowflake/database.py index 298521d..4e4297b 100644 --- a/prefect_snowflake/database.py +++ b/prefect_snowflake/database.py @@ -147,19 +147,15 @@ def get_connection(self, **connect_kwargs: Dict[str, Any]) -> SnowflakeConnectio "schema": self.schema_, } connection = self.credentials.get_client(**connect_kwargs, **connect_params) + self._connection = connection + self.logger.info("Started a new connection to Snowflake.") return connection def _start_connection(self): """ Starts Snowflake database connection. """ - self._connection = self.get_connection() - - def block_initialization(self) -> None: - super().block_initialization() - if self._connection is None: - self._start_connection() - + self.get_connection() if self._unique_cursors is None: self._unique_cursors = {} @@ -174,6 +170,8 @@ def _get_cursor(self, inputs: Dict[str, Any]) -> Tuple[bool, SnowflakeCursor]: Returns: Whether a cursor is new and a Snowflake cursor. """ + self._start_connection() + input_hash = hash_objects(inputs) if input_hash is None: raise RuntimeError( @@ -232,6 +230,10 @@ def reset_cursors(self) -> None: print(conn.fetch_one("SELECT * FROM customers")) # should be Ford again ``` """ # noqa + if not self._unique_cursors: + self.logger.info("There were no cursors to reset.") + return + input_hashes = tuple(self._unique_cursors.keys()) for input_hash in input_hashes: cursor = self._unique_cursors.pop(input_hash) @@ -462,6 +464,8 @@ async def execute( ) ``` """ # noqa + self._start_connection() + inputs = dict( command=operation, params=parameters, @@ -506,6 +510,8 @@ async def execute_many( ) ``` """ # noqa + self._start_connection() + inputs = dict( command=operation, seqparams=seq_of_parameters, @@ -523,10 +529,12 @@ def close(self): try: self.reset_cursors() finally: - if self._connection is not None: - self._connection.close() - self._connection = None - self.logger.info("Successfully closed the Snowflake connection.") + if self._connection is None: + self.logger.info("There was no connection open to be closed.") + return + self._connection.close() + self._connection = None + self.logger.info("Successfully closed the Snowflake connection.") def __enter__(self): """ @@ -549,7 +557,6 @@ def __getstate__(self): def __setstate__(self, data: dict): """Reset connection and cursors upon loading.""" self.__dict__.update(data) - self._unique_cursors = {} self._start_connection() diff --git a/tests/test_database.py b/tests/test_database.py index 1ee660d..126b051 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -192,15 +192,20 @@ def snowflake_connector(self, connector_params, snowflake_connect_mock): return connector def test_block_initialization(self, snowflake_connector): - assert snowflake_connector._connection is not None - assert snowflake_connector._unique_cursors == {} + assert snowflake_connector._connection is None + assert snowflake_connector._unique_cursors is None - def test_get_connection(self, snowflake_connector: SnowflakeConnector): + def test_get_connection(self, snowflake_connector: SnowflakeConnector, caplog): connection = snowflake_connector.get_connection() assert snowflake_connector._connection is connection + assert caplog.records[0].msg == "Started a new connection to Snowflake." - def test_reset_cursors(self, snowflake_connector: SnowflakeConnector): + def test_reset_cursors(self, snowflake_connector: SnowflakeConnector, caplog): mock_cursor = MagicMock() + snowflake_connector.reset_cursors() + assert caplog.records[0].msg == "There were no cursors to reset." + + snowflake_connector._start_connection() snowflake_connector._unique_cursors["12345"] = mock_cursor snowflake_connector.reset_cursors() assert len(snowflake_connector._unique_cursors) == 0 @@ -231,13 +236,20 @@ def test_execute_Many(self, snowflake_connector: SnowflakeConnector): is None ) - def test_close(self, snowflake_connector: SnowflakeConnector): + def test_close(self, snowflake_connector: SnowflakeConnector, caplog): + assert snowflake_connector.close() is None + assert caplog.records[0].msg == "There were no cursors to reset." + assert caplog.records[1].msg == "There was no connection open to be closed." + + snowflake_connector._start_connection() assert snowflake_connector.close() is None assert snowflake_connector._connection is None assert snowflake_connector._unique_cursors == {} def test_context_management(self, snowflake_connector): with snowflake_connector: - pass + assert snowflake_connector._connection is None + assert snowflake_connector._unique_cursors is None + assert snowflake_connector._connection is None - assert snowflake_connector._unique_cursors == {} + assert snowflake_connector._unique_cursors is None