@@ -585,3 +585,222 @@ def test_all_dataframes_serialize_to_parquet(self, key, df):
585585 df_cleaned .to_parquet (in_memory_file )
586586 except : # noqa: E722
587587 self .fail (f"serializing to parquet failed for { key } " )
588+
589+
590+ class TestFederatedAuth (unittest .TestCase ):
591+ @mock .patch ("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials" )
592+ def test_federated_auth_params_trino (self , mock_get_credentials ):
593+ """Test that Trino federated auth updates the Authorization header with Bearer token."""
594+ from deepnote_toolkit .sql .sql_execution import (
595+ FederatedAuthResponseData ,
596+ _handle_federated_auth_params ,
597+ )
598+
599+ # Setup mock to return Trino credentials
600+ mock_get_credentials .return_value = FederatedAuthResponseData (
601+ integrationType = "trino" ,
602+ accessToken = "test-trino-access-token" ,
603+ )
604+
605+ # Create a sql_alchemy_dict with federatedAuthParams and the expected structure
606+ sql_alchemy_dict = {
607+ "url" : "trino://user@localhost:8080/catalog" ,
608+ "params" : {
609+ "connect_args" : {
610+ "http_headers" : {
611+ "Authorization" : "Bearer old-token" ,
612+ }
613+ }
614+ },
615+ "federatedAuthParams" : {
616+ "integrationId" : "test-integration-id" ,
617+ "authContextToken" : "test-auth-context-token" ,
618+ },
619+ }
620+
621+ # Call the function
622+ _handle_federated_auth_params (sql_alchemy_dict )
623+
624+ # Verify the API was called with correct params
625+ mock_get_credentials .assert_called_once_with (
626+ "test-integration-id" , "test-auth-context-token"
627+ )
628+
629+ # Verify the Authorization header was updated with the new token
630+ self .assertEqual (
631+ sql_alchemy_dict ["params" ]["connect_args" ]["http_headers" ]["Authorization" ],
632+ "Bearer test-trino-access-token" ,
633+ )
634+
635+ @mock .patch ("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials" )
636+ def test_federated_auth_params_bigquery (self , mock_get_credentials ):
637+ """Test that BigQuery federated auth updates the access_token in params."""
638+ from deepnote_toolkit .sql .sql_execution import (
639+ FederatedAuthResponseData ,
640+ _handle_federated_auth_params ,
641+ )
642+
643+ # Setup mock to return BigQuery credentials
644+ mock_get_credentials .return_value = FederatedAuthResponseData (
645+ integrationType = "big-query" ,
646+ accessToken = "test-bigquery-access-token" ,
647+ )
648+
649+ # Create a sql_alchemy_dict with federatedAuthParams
650+ sql_alchemy_dict = {
651+ "url" : "bigquery://?user_supplied_client=true" ,
652+ "params" : {
653+ "access_token" : "old-access-token" ,
654+ "project" : "test-project" ,
655+ },
656+ "federatedAuthParams" : {
657+ "integrationId" : "test-bigquery-integration-id" ,
658+ "authContextToken" : "test-bigquery-auth-context-token" ,
659+ },
660+ }
661+
662+ # Call the function
663+ _handle_federated_auth_params (sql_alchemy_dict )
664+
665+ # Verify the API was called with correct params
666+ mock_get_credentials .assert_called_once_with (
667+ "test-bigquery-integration-id" , "test-bigquery-auth-context-token"
668+ )
669+
670+ # Verify the access_token was updated with the new token
671+ self .assertEqual (
672+ sql_alchemy_dict ["params" ]["access_token" ],
673+ "test-bigquery-access-token" ,
674+ )
675+
676+ @mock .patch ("deepnote_toolkit.sql.sql_execution.logger" )
677+ @mock .patch ("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials" )
678+ def test_federated_auth_params_snowflake (self , mock_get_credentials , mock_logger ):
679+ """Test that Snowflake federated auth logs a warning since it's not supported yet."""
680+ from deepnote_toolkit .sql .sql_execution import (
681+ FederatedAuthResponseData ,
682+ _handle_federated_auth_params ,
683+ )
684+
685+ # Setup mock to return Snowflake credentials
686+ mock_get_credentials .return_value = FederatedAuthResponseData (
687+ integrationType = "snowflake" ,
688+ accessToken = "test-snowflake-access-token" ,
689+ )
690+
691+ # Create a sql_alchemy_dict with federatedAuthParams
692+ sql_alchemy_dict = {
693+ "url" : "snowflake://test@test?warehouse=&role=&application=Deepnote_Workspaces" ,
694+ "params" : {},
695+ "federatedAuthParams" : {
696+ "integrationId" : "test-snowflake-integration-id" ,
697+ "authContextToken" : "test-snowflake-auth-context-token" ,
698+ },
699+ }
700+
701+ # Store original params to verify they remain unchanged
702+ original_params = sql_alchemy_dict ["params" ].copy ()
703+
704+ # Call the function
705+ _handle_federated_auth_params (sql_alchemy_dict )
706+
707+ # Verify the API was called with correct params
708+ mock_get_credentials .assert_called_once_with (
709+ "test-snowflake-integration-id" , "test-snowflake-auth-context-token"
710+ )
711+
712+ # Verify a warning was logged
713+ mock_logger .warning .assert_called_once_with (
714+ "Snowflake federated auth is not supported yet, using the original connection URL"
715+ )
716+
717+ # Verify params were NOT modified (snowflake is not supported yet)
718+ self .assertEqual (sql_alchemy_dict ["params" ], original_params )
719+
720+ def test_federated_auth_params_not_present (self ):
721+ """Test that no action is taken when federatedAuthParams is not present."""
722+ from deepnote_toolkit .sql .sql_execution import _handle_federated_auth_params
723+
724+ # Create a sql_alchemy_dict without federatedAuthParams
725+ sql_alchemy_dict = {
726+ "url" : "trino://user@localhost:8080/catalog" ,
727+ "params" : {
728+ "connect_args" : {
729+ "http_headers" : {"Authorization" : "Bearer original-token" }
730+ }
731+ },
732+ }
733+
734+ original_dict = json .loads (json .dumps (sql_alchemy_dict ))
735+
736+ # Call the function
737+ _handle_federated_auth_params (sql_alchemy_dict )
738+
739+ # Verify the dict was not modified
740+ self .assertEqual (sql_alchemy_dict , original_dict )
741+
742+ @mock .patch ("deepnote_toolkit.sql.sql_execution.logger" )
743+ def test_federated_auth_params_invalid_params (self , mock_logger ):
744+ """Test that invalid federated auth params logs an error and returns early."""
745+ from deepnote_toolkit .sql .sql_execution import _handle_federated_auth_params
746+
747+ # Create a sql_alchemy_dict with invalid federatedAuthParams (missing required fields)
748+ sql_alchemy_dict = {
749+ "url" : "trino://user@localhost:8080/catalog" ,
750+ "params" : {},
751+ "federatedAuthParams" : {
752+ "invalidField" : "value" ,
753+ },
754+ }
755+
756+ original_dict = json .loads (json .dumps (sql_alchemy_dict ))
757+
758+ # Call the function
759+ _handle_federated_auth_params (sql_alchemy_dict )
760+
761+ # Verify an error was logged
762+ mock_logger .error .assert_called_once ()
763+ call_args = mock_logger .error .call_args
764+ self .assertIn ("Invalid federated auth params" , call_args [0 ][0 ])
765+
766+ self .assertEqual (sql_alchemy_dict , original_dict )
767+
768+ @mock .patch ("deepnote_toolkit.sql.sql_execution.logger" )
769+ @mock .patch ("deepnote_toolkit.sql.sql_execution._get_federated_auth_credentials" )
770+ def test_federated_auth_params_unsupported_integration_type (
771+ self , mock_get_credentials , mock_logger
772+ ):
773+ """Test that unsupported integration type logs an error."""
774+ from deepnote_toolkit .sql .sql_execution import (
775+ FederatedAuthResponseData ,
776+ _handle_federated_auth_params ,
777+ )
778+
779+ # Setup mock to return unknown integration type
780+ mock_get_credentials .return_value = FederatedAuthResponseData (
781+ integrationType = "unknown-database" ,
782+ accessToken = "test-token" ,
783+ )
784+
785+ # Create a sql_alchemy_dict with federatedAuthParams
786+ sql_alchemy_dict = {
787+ "url" : "unknown://host/db" ,
788+ "params" : {},
789+ "federatedAuthParams" : {
790+ "integrationId" : "test-integration-id" ,
791+ "authContextToken" : "test-auth-context-token" ,
792+ },
793+ }
794+
795+ original_dict = json .loads (json .dumps (sql_alchemy_dict ))
796+
797+ # Call the function
798+ _handle_federated_auth_params (sql_alchemy_dict )
799+
800+ # Verify an error was logged for unsupported integration type
801+ mock_logger .error .assert_called_once_with (
802+ "Unsupported integration type: %s, try updating toolkit version" ,
803+ "unknown-database" ,
804+ )
805+
806+ self .assertEqual (sql_alchemy_dict , original_dict )
0 commit comments