Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Delay connection to Snowflake until needed #58

Merged
merged 4 commits into from
Jan 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Do not start connection upon instantiating `SnowflakeConnector` until its methods are called - [#58](https://github.com/PrefectHQ/prefect-snowflake/pull/58)

### Deprecated

### Removed
Expand Down
31 changes: 19 additions & 12 deletions prefect_snowflake/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,15 @@ def get_connection(self, **connect_kwargs: Dict[str, Any]) -> SnowflakeConnectio
"schema": self.schema_,
}
connection = self.credentials.get_client(**connect_kwargs, **connect_params)
self._connection = connection
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
self.logger.info("Started a new connection to Snowflake.")
return connection

def _start_connection(self):
"""
Starts Snowflake database connection.
"""
self._connection = self.get_connection()

def block_initialization(self) -> None:
super().block_initialization()
if self._connection is None:
self._start_connection()

self.get_connection()
if self._unique_cursors is None:
self._unique_cursors = {}

Expand All @@ -174,6 +170,8 @@ def _get_cursor(self, inputs: Dict[str, Any]) -> Tuple[bool, SnowflakeCursor]:
Returns:
Whether a cursor is new and a Snowflake cursor.
"""
self._start_connection()

input_hash = hash_objects(inputs)
if input_hash is None:
raise RuntimeError(
Expand Down Expand Up @@ -232,6 +230,10 @@ def reset_cursors(self) -> None:
print(conn.fetch_one("SELECT * FROM customers")) # should be Ford again
```
""" # noqa
if not self._unique_cursors:
self.logger.info("There were no cursors to reset.")
return

input_hashes = tuple(self._unique_cursors.keys())
for input_hash in input_hashes:
cursor = self._unique_cursors.pop(input_hash)
Expand Down Expand Up @@ -462,6 +464,8 @@ async def execute(
)
```
""" # noqa
self._start_connection()

inputs = dict(
command=operation,
params=parameters,
Expand Down Expand Up @@ -506,6 +510,8 @@ async def execute_many(
)
```
""" # noqa
self._start_connection()

inputs = dict(
command=operation,
seqparams=seq_of_parameters,
Expand All @@ -523,10 +529,12 @@ def close(self):
try:
self.reset_cursors()
finally:
if self._connection is not None:
self._connection.close()
self._connection = None
self.logger.info("Successfully closed the Snowflake connection.")
if self._connection is None:
self.logger.info("There was no connection open to be closed.")
return
self._connection.close()
self._connection = None
self.logger.info("Successfully closed the Snowflake connection.")

def __enter__(self):
"""
Expand All @@ -549,7 +557,6 @@ def __getstate__(self):
def __setstate__(self, data: dict):
"""Reset connection and cursors upon loading."""
self.__dict__.update(data)
self._unique_cursors = {}
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
self._start_connection()


Expand Down
26 changes: 19 additions & 7 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,20 @@ def snowflake_connector(self, connector_params, snowflake_connect_mock):
return connector

def test_block_initialization(self, snowflake_connector):
assert snowflake_connector._connection is not None
assert snowflake_connector._unique_cursors == {}
assert snowflake_connector._connection is None
assert snowflake_connector._unique_cursors is None

def test_get_connection(self, snowflake_connector: SnowflakeConnector):
def test_get_connection(self, snowflake_connector: SnowflakeConnector, caplog):
connection = snowflake_connector.get_connection()
assert snowflake_connector._connection is connection
assert caplog.records[0].msg == "Started a new connection to Snowflake."

def test_reset_cursors(self, snowflake_connector: SnowflakeConnector):
def test_reset_cursors(self, snowflake_connector: SnowflakeConnector, caplog):
mock_cursor = MagicMock()
snowflake_connector.reset_cursors()
assert caplog.records[0].msg == "There were no cursors to reset."

snowflake_connector._start_connection()
snowflake_connector._unique_cursors["12345"] = mock_cursor
snowflake_connector.reset_cursors()
assert len(snowflake_connector._unique_cursors) == 0
Expand Down Expand Up @@ -231,13 +236,20 @@ def test_execute_Many(self, snowflake_connector: SnowflakeConnector):
is None
)

def test_close(self, snowflake_connector: SnowflakeConnector):
def test_close(self, snowflake_connector: SnowflakeConnector, caplog):
assert snowflake_connector.close() is None
assert caplog.records[0].msg == "There were no cursors to reset."
assert caplog.records[1].msg == "There was no connection open to be closed."

snowflake_connector._start_connection()
assert snowflake_connector.close() is None
assert snowflake_connector._connection is None
assert snowflake_connector._unique_cursors == {}

def test_context_management(self, snowflake_connector):
with snowflake_connector:
pass
assert snowflake_connector._connection is None
assert snowflake_connector._unique_cursors is None

assert snowflake_connector._connection is None
assert snowflake_connector._unique_cursors == {}
assert snowflake_connector._unique_cursors is None