diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py index 7dce8fd4ea806..71d8f13b5fcb5 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/connections.py @@ -54,6 +54,7 @@ ) from airflow.api_fastapi.logging.decorators import action_logging from airflow.configuration import conf +from airflow.exceptions import AirflowNotFoundException from airflow.models import Connection from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.utils.db import create_default_connections as db_create_default_connections @@ -212,9 +213,7 @@ def patch_connection( @connections_router.post("/test", dependencies=[Depends(requires_access_connection(method="POST"))]) -def test_connection( - test_body: ConnectionBody, -) -> ConnectionTestResponse: +def test_connection(test_body: ConnectionBody) -> ConnectionTestResponse: """ Test an API connection. @@ -232,9 +231,17 @@ def test_connection( transient_conn_id = get_random_string() conn_env_var = f"{CONN_ENV_PREFIX}{transient_conn_id.upper()}" try: - data = test_body.model_dump(by_alias=True) - data["conn_id"] = transient_conn_id - conn = Connection(**data) + # Try to get existing connection and merge with provided values + try: + existing_conn = Connection.get_connection_from_secrets(test_body.connection_id) + existing_conn.conn_id = transient_conn_id + update_orm_from_pydantic(existing_conn, test_body) + conn = existing_conn + except AirflowNotFoundException: + data = test_body.model_dump(by_alias=True) + data["conn_id"] = transient_conn_id + conn = Connection(**data) + os.environ[conn_env_var] = conn.get_uri() test_status, test_message = conn.test_connection() return ConnectionTestResponse.model_validate({"status": test_status, "message": test_message}) diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py index d2c48f945d480..cccbcd367ca74 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py @@ -16,12 +16,13 @@ # under the License. from __future__ import annotations +import json import os from importlib.metadata import PackageNotFoundError, metadata from unittest import mock import pytest -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.orm import Session from airflow.models import Connection @@ -1001,6 +1002,140 @@ def test_should_respond_403_by_default(self, test_client, body): "Contact your deployment admin to enable it." } + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_should_merge_password_with_existing_connection(self, test_client, session): + connection = Connection( + conn_id=TEST_CONN_ID, + conn_type="sqlite", + password="existing_password", + ) + session.add(connection) + session.commit() + initial_count = session.scalar(select(func.count()).select_from(Connection)) + + captured_value = {} + + def mock_test_connection(self): + captured_value["password"] = self.password + captured_value["conn_type"] = self.conn_type + return True, "mocked" + + body = { + "connection_id": TEST_CONN_ID, + "conn_type": "new_sqlite", + "password": "***", + } + + with mock.patch.object(Connection, "test_connection", mock_test_connection): + response = test_client.post("/connections/test", json=body) + + assert response.status_code == 200 + assert response.json()["status"] is True + # Verify that the existing password was used, not "***" + assert captured_value["password"] == "existing_password" + # Verify that payload info were used for other fields + assert captured_value["conn_type"] == "new_sqlite" + + # Verify DB was not mutated + session.expire_all() + db_conn = session.scalar(select(Connection).filter_by(conn_id=TEST_CONN_ID)) + assert db_conn.password == "existing_password" + assert session.scalar(select(func.count()).select_from(Connection)) == initial_count + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_should_merge_extra_with_existing_connection(self, test_client, session): + connection = Connection( + conn_id=TEST_CONN_ID, + conn_type="fs", + extra='{"path": "/", "existing_key": "existing_value"}', + ) + session.add(connection) + session.commit() + initial_count = session.scalar(select(func.count()).select_from(Connection)) + + captured_extra = {} + + def mock_test_connection(self): + captured_extra["value"] = self.extra + return True, "mocked" + + body = { + "connection_id": TEST_CONN_ID, + "conn_type": "fs", + "extra": '{"path": "/", "new_key": "new_value"}', + } + + with mock.patch.object(Connection, "test_connection", mock_test_connection): + response = test_client.post("/connections/test", json=body) + + assert response.status_code == 200 + assert response.json()["status"] is True + # Verify that new_key is reflected in the merged extra + merged_extra = json.loads(captured_extra["value"]) + assert merged_extra["new_key"] == "new_value" + assert merged_extra["path"] == "/" + + # Verify DB was not mutated + session.expire_all() + db_conn = session.scalar(select(Connection).filter_by(conn_id=TEST_CONN_ID)) + assert json.loads(db_conn.extra) == {"path": "/", "existing_key": "existing_value"} + assert session.scalar(select(func.count()).select_from(Connection)) == initial_count + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_should_merge_both_password_and_extra(self, test_client, session): + connection = Connection( + conn_id=TEST_CONN_ID, + conn_type="fs", + password="existing_password", + extra='{"path": "/", "existing_key": "existing_value"}', + ) + session.add(connection) + session.commit() + initial_count = session.scalar(select(func.count()).select_from(Connection)) + + captured_values = {} + + def mock_test_connection(self): + captured_values["password"] = self.password + captured_values["extra"] = self.extra + return True, "mocked" + + body = { + "connection_id": TEST_CONN_ID, + "conn_type": "fs", + "password": "***", + "extra": '{"path": "/", "new_key": "new_value"}', + } + + with mock.patch.object(Connection, "test_connection", mock_test_connection): + response = test_client.post("/connections/test", json=body) + + assert response.status_code == 200 + assert response.json()["status"] is True + # Verify that the existing password was used, not "***" + assert captured_values["password"] == "existing_password" + # Verify that new_key is reflected in the merged extra + merged_extra = json.loads(captured_values["extra"]) + assert merged_extra["new_key"] == "new_value" + assert merged_extra["path"] == "/" + + # Verify DB was not mutated + session.expire_all() + db_conn = session.scalar(select(Connection).filter_by(conn_id=TEST_CONN_ID)) + assert db_conn.password == "existing_password" + assert json.loads(db_conn.extra) == {"path": "/", "existing_key": "existing_value"} + assert session.scalar(select(func.count()).select_from(Connection)) == initial_count + + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_should_test_new_connection_without_existing(self, test_client): + body = { + "connection_id": "non_existent_conn", + "conn_type": "sqlite", + } + response = test_client.post("/connections/test", json=body) + assert response.status_code == 200 + assert response.json()["status"] is True + class TestCreateDefaultConnections(TestConnectionEndpoint): def test_should_respond_204(self, test_client, session):