Skip to content

Commit 42c90f9

Browse files
committed
Add tests for federated auth token fetch from webapp
1 parent 8d31236 commit 42c90f9

File tree

2 files changed

+223
-0
lines changed

2 files changed

+223
-0
lines changed

deepnote_toolkit/sql/sql_execution.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def _generate_temporary_credentials(integration_id):
267267
def _get_federated_auth_credentials(
268268
integration_id: str, user_pod_auth_context_token: str
269269
) -> FederatedAuthResponseData:
270+
"""Get federated auth credentials for the given integration ID and user pod auth context token."""
271+
270272
url = get_absolute_userpod_api_url(
271273
f"integrations/federated-auth-token/{integration_id}"
272274
)
@@ -327,6 +329,8 @@ def _handle_federated_auth_params(sql_alchemy_dict: dict[str, Any]) -> None:
327329
] = f"Bearer {federated_auth.accessToken}"
328330
elif federated_auth.integrationType == "big-query":
329331
sql_alchemy_dict["params"]["access_token"] = federated_auth.accessToken
332+
elif federated_auth.integrationType == "snowflake":
333+
logger.warning("Snowflake federated auth is not supported yet, using the original connection URL")
330334
else:
331335
logger.error(
332336
"Unsupported integration type: %s, try updating toolkit version",

tests/unit/test_sql_execution.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,3 +585,222 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
585585
df_cleaned.to_parquet(in_memory_file)
586586
except: # noqa: E722
587587
self.fail(f"serializing to parquet failed for {key}")
588+
589+
590+
class TestFederatedAuth(unittest.TestCase):
591+
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
592+
def test_federated_auth_params_trino(self, mock_get_credentials):
593+
"""Test that Trino federated auth updates the Authorization header with Bearer token."""
594+
from deepnote_toolkit.sql.sql_execution import (
595+
FederatedAuthResponseData,
596+
_handle_federated_auth_params,
597+
)
598+
599+
# Setup mock to return Trino credentials
600+
mock_get_credentials.return_value = FederatedAuthResponseData(
601+
integrationType="trino",
602+
accessToken="test-trino-access-token",
603+
)
604+
605+
# Create a sql_alchemy_dict with federatedAuthParams and the expected structure
606+
sql_alchemy_dict = {
607+
"url": "trino://user@localhost:8080/catalog",
608+
"params": {
609+
"connect_args": {
610+
"http_headers": {
611+
"Authorization": "Bearer old-token",
612+
}
613+
}
614+
},
615+
"federatedAuthParams": {
616+
"integrationId": "test-integration-id",
617+
"authContextToken": "test-auth-context-token",
618+
},
619+
}
620+
621+
# Call the function
622+
_handle_federated_auth_params(sql_alchemy_dict)
623+
624+
# Verify the API was called with correct params
625+
mock_get_credentials.assert_called_once_with(
626+
"test-integration-id", "test-auth-context-token"
627+
)
628+
629+
# Verify the Authorization header was updated with the new token
630+
self.assertEqual(
631+
sql_alchemy_dict["params"]["connect_args"]["http_headers"]["Authorization"],
632+
"Bearer test-trino-access-token",
633+
)
634+
635+
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
636+
def test_federated_auth_params_bigquery(self, mock_get_credentials):
637+
"""Test that BigQuery federated auth updates the access_token in params."""
638+
from deepnote_toolkit.sql.sql_execution import (
639+
FederatedAuthResponseData,
640+
_handle_federated_auth_params,
641+
)
642+
643+
# Setup mock to return BigQuery credentials
644+
mock_get_credentials.return_value = FederatedAuthResponseData(
645+
integrationType="big-query",
646+
accessToken="test-bigquery-access-token",
647+
)
648+
649+
# Create a sql_alchemy_dict with federatedAuthParams
650+
sql_alchemy_dict = {
651+
"url": "bigquery://?user_supplied_client=true",
652+
"params": {
653+
"access_token": "old-access-token",
654+
"project": "test-project",
655+
},
656+
"federatedAuthParams": {
657+
"integrationId": "test-bigquery-integration-id",
658+
"authContextToken": "test-bigquery-auth-context-token",
659+
},
660+
}
661+
662+
# Call the function
663+
_handle_federated_auth_params(sql_alchemy_dict)
664+
665+
# Verify the API was called with correct params
666+
mock_get_credentials.assert_called_once_with(
667+
"test-bigquery-integration-id", "test-bigquery-auth-context-token"
668+
)
669+
670+
# Verify the access_token was updated with the new token
671+
self.assertEqual(
672+
sql_alchemy_dict["params"]["access_token"],
673+
"test-bigquery-access-token",
674+
)
675+
676+
@mock.patch("deepnote_toolkit.sql.sql_execution.logger")
677+
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
678+
def test_federated_auth_params_snowflake(self, mock_get_credentials, mock_logger):
679+
"""Test that Snowflake federated auth logs a warning since it's not supported yet."""
680+
from deepnote_toolkit.sql.sql_execution import (
681+
FederatedAuthResponseData,
682+
_handle_federated_auth_params,
683+
)
684+
685+
# Setup mock to return Snowflake credentials
686+
mock_get_credentials.return_value = FederatedAuthResponseData(
687+
integrationType="snowflake",
688+
accessToken="test-snowflake-access-token",
689+
)
690+
691+
# Create a sql_alchemy_dict with federatedAuthParams
692+
sql_alchemy_dict = {
693+
"url": "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces",
694+
"params": {},
695+
"federatedAuthParams": {
696+
"integrationId": "test-snowflake-integration-id",
697+
"authContextToken": "test-snowflake-auth-context-token",
698+
},
699+
}
700+
701+
# Store original params to verify they remain unchanged
702+
original_params = sql_alchemy_dict["params"].copy()
703+
704+
# Call the function
705+
_handle_federated_auth_params(sql_alchemy_dict)
706+
707+
# Verify the API was called with correct params
708+
mock_get_credentials.assert_called_once_with(
709+
"test-snowflake-integration-id", "test-snowflake-auth-context-token"
710+
)
711+
712+
# Verify a warning was logged
713+
mock_logger.warning.assert_called_once_with(
714+
"Snowflake federated auth is not supported yet, using the original connection URL"
715+
)
716+
717+
# Verify params were NOT modified (snowflake is not supported yet)
718+
self.assertEqual(sql_alchemy_dict["params"], original_params)
719+
720+
def test_federated_auth_params_not_present(self):
721+
"""Test that no action is taken when federatedAuthParams is not present."""
722+
from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params
723+
724+
# Create a sql_alchemy_dict without federatedAuthParams
725+
sql_alchemy_dict = {
726+
"url": "trino://user@localhost:8080/catalog",
727+
"params": {
728+
"connect_args": {
729+
"http_headers": {"Authorization": "Bearer original-token"}
730+
}
731+
},
732+
}
733+
734+
original_dict = json.loads(json.dumps(sql_alchemy_dict))
735+
736+
# Call the function
737+
_handle_federated_auth_params(sql_alchemy_dict)
738+
739+
# Verify the dict was not modified
740+
self.assertEqual(sql_alchemy_dict, original_dict)
741+
742+
@mock.patch("deepnote_toolkit.sql.sql_execution.logger")
743+
def test_federated_auth_params_invalid_params(self, mock_logger):
744+
"""Test that invalid federated auth params logs an error and returns early."""
745+
from deepnote_toolkit.sql.sql_execution import _handle_federated_auth_params
746+
747+
# Create a sql_alchemy_dict with invalid federatedAuthParams (missing required fields)
748+
sql_alchemy_dict = {
749+
"url": "trino://user@localhost:8080/catalog",
750+
"params": {},
751+
"federatedAuthParams": {
752+
"invalidField": "value",
753+
},
754+
}
755+
756+
original_dict = json.loads(json.dumps(sql_alchemy_dict))
757+
758+
# Call the function
759+
_handle_federated_auth_params(sql_alchemy_dict)
760+
761+
# Verify an error was logged
762+
mock_logger.error.assert_called_once()
763+
call_args = mock_logger.error.call_args
764+
self.assertIn("Invalid federated auth params", call_args[0][0])
765+
766+
self.assertEqual(sql_alchemy_dict, original_dict)
767+
768+
@mock.patch("deepnote_toolkit.sql.sql_execution.logger")
769+
@mock.patch("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials")
770+
def test_federated_auth_params_unsupported_integration_type(
771+
self, mock_get_credentials, mock_logger
772+
):
773+
"""Test that unsupported integration type logs an error."""
774+
from deepnote_toolkit.sql.sql_execution import (
775+
FederatedAuthResponseData,
776+
_handle_federated_auth_params,
777+
)
778+
779+
# Setup mock to return unknown integration type
780+
mock_get_credentials.return_value = FederatedAuthResponseData(
781+
integrationType="unknown-database",
782+
accessToken="test-token",
783+
)
784+
785+
# Create a sql_alchemy_dict with federatedAuthParams
786+
sql_alchemy_dict = {
787+
"url": "unknown://host/db",
788+
"params": {},
789+
"federatedAuthParams": {
790+
"integrationId": "test-integration-id",
791+
"authContextToken": "test-auth-context-token",
792+
},
793+
}
794+
795+
original_dict = json.loads(json.dumps(sql_alchemy_dict))
796+
797+
# Call the function
798+
_handle_federated_auth_params(sql_alchemy_dict)
799+
800+
# Verify an error was logged for unsupported integration type
801+
mock_logger.error.assert_called_once_with(
802+
"Unsupported integration type: %s, try updating toolkit version",
803+
"unknown-database",
804+
)
805+
806+
self.assertEqual(sql_alchemy_dict, original_dict)

0 commit comments

Comments
 (0)