diff --git a/.changes/unreleased/Features-20230131-022659.yaml b/.changes/unreleased/Features-20230131-022659.yaml new file mode 100644 index 000000000..c5d01c71f --- /dev/null +++ b/.changes/unreleased/Features-20230131-022659.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add configurable setting for keeping connections open on Snowflake +time: 2023-01-31T02:26:59.701589-08:00 +custom: + Author: versusfacit joshuataylor + Issue: "854" diff --git a/dbt/adapters/snowflake/connections.py b/dbt/adapters/snowflake/connections.py index 69f654c16..e020c47ed 100644 --- a/dbt/adapters/snowflake/connections.py +++ b/dbt/adapters/snowflake/connections.py @@ -78,6 +78,7 @@ class SnowflakeCredentials(Credentials): retry_on_database_errors: bool = False retry_all: bool = False insecure_mode: Optional[bool] = False + reuse_connections: Optional[bool] = None def __post_init__(self): if self.authenticator != "oauth" and ( @@ -151,6 +152,7 @@ def auth_args(self): result["client_store_temporary_credential"] = True # enable mfa token cache for linux result["client_request_mfa_token"] = True + result["reuse_connections"] = self.reuse_connections result["private_key"] = self._get_private_key() return result @@ -486,3 +488,12 @@ def add_query(self, sql, auto_begin=True, bindings=None, abridge_sql_log=False): ) return connection, cursor + + def release(self) -> None: + """Reuse connections by deferring release until adapter context manager in core + resets adapters. This cleanup_all happens before Python teardown. Idle connections + incur no costs while waiting in the connection pool.""" + if self.profile.credentials.reuse_connections: # type: ignore + return + else: + super().release() diff --git a/tests/unit/test_snowflake_adapter.py b/tests/unit/test_snowflake_adapter.py index 7b7ab4101..49186206e 100644 --- a/tests/unit/test_snowflake_adapter.py +++ b/tests/unit/test_snowflake_adapter.py @@ -262,7 +262,7 @@ def test_client_session_keep_alive_false_by_default(self): client_session_keep_alive=False, database='test_database', role=None, schema='public', user='test_user', warehouse='test_warehouse', private_key=None, application='dbt', insecure_mode=False, - session_parameters={}), + session_parameters={}, reuse_connections=None), ]) def test_client_session_keep_alive_true(self): @@ -279,7 +279,7 @@ def test_client_session_keep_alive_true(self): client_session_keep_alive=True, database='test_database', role=None, schema='public', user='test_user', warehouse='test_warehouse', private_key=None, application='dbt', insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) def test_user_pass_authentication(self): @@ -298,7 +298,7 @@ def test_user_pass_authentication(self): password='test_password', role=None, schema='public', user='test_user', warehouse='test_warehouse', private_key=None, application='dbt', insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) def test_authenticator_user_pass_authentication(self): @@ -318,9 +318,9 @@ def test_authenticator_user_pass_authentication(self): password='test_password', role=None, schema='public', user='test_user', warehouse='test_warehouse', authenticator='test_sso_url', private_key=None, - application='dbt', client_request_mfa_token=True, + application='dbt', client_request_mfa_token=True, client_store_temporary_credential=True, insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) def test_authenticator_externalbrowser_authentication(self): @@ -338,9 +338,9 @@ def test_authenticator_externalbrowser_authentication(self): client_session_keep_alive=False, database='test_database', role=None, schema='public', user='test_user', warehouse='test_warehouse', authenticator='externalbrowser', - private_key=None, application='dbt', client_request_mfa_token=True, + private_key=None, application='dbt', client_request_mfa_token=True, client_store_temporary_credential=True, insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) def test_authenticator_oauth_authentication(self): @@ -359,9 +359,9 @@ def test_authenticator_oauth_authentication(self): client_session_keep_alive=False, database='test_database', role=None, schema='public', user='test_user', warehouse='test_warehouse', authenticator='oauth', token='my-oauth-token', - private_key=None, application='dbt', client_request_mfa_token=True, + private_key=None, application='dbt', client_request_mfa_token=True, client_store_temporary_credential=True, insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) @mock.patch('dbt.adapters.snowflake.SnowflakeCredentials._get_private_key', return_value='test_key') @@ -383,7 +383,7 @@ def test_authenticator_private_key_authentication(self, mock_get_private_key): role=None, schema='public', user='test_user', warehouse='test_warehouse', private_key='test_key', application='dbt', insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) @mock.patch('dbt.adapters.snowflake.SnowflakeCredentials._get_private_key', return_value='test_key') @@ -405,7 +405,7 @@ def test_authenticator_private_key_authentication_no_passphrase(self, mock_get_p role=None, schema='public', user='test_user', warehouse='test_warehouse', private_key='test_key', application='dbt', insecure_mode=False, - session_parameters={}) + session_parameters={}, reuse_connections=None) ]) def test_query_tag(self): @@ -422,7 +422,26 @@ def test_query_tag(self): password='test_password', role=None, schema='public', user='test_user', warehouse='test_warehouse', private_key=None, application='dbt', insecure_mode=False, - session_parameters={"QUERY_TAG": "test_query_tag"}) + session_parameters={"QUERY_TAG": "test_query_tag"}, reuse_connections=None) + ]) + + def test_reuse_connections_with_keep_alive(self): + self.config.credentials = self.config.credentials.replace( + reuse_connections=True, + client_session_keep_alive=True + ) + self.adapter = SnowflakeAdapter(self.config) + conn = self.adapter.connections.set_connection_name(name='new_connection_with_new_config') + + self.snowflake.assert_not_called() + conn.handle + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=True, + client_session_keep_alive=True, database='test_database', + role=None, schema='public', user='test_user', warehouse='test_warehouse', + private_key=None, application='dbt', insecure_mode=False, + session_parameters={}, reuse_connections=True) ])