diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 73265015836b6..d94c472d30285 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -752,6 +752,7 @@ repos: ^.*/conf_constants\.py$| ^.*/provider_conf\.py$| ^devel-common/src/sphinx_exts/removemarktransform\.py| + ^devel-common/src/tests_common/test_utils/db\.py| ^airflow-core/newsfragments/41761.significant\.rst$| ^scripts/ci/pre_commit/vendor_k8s_json_schema\.py$| ^scripts/ci/docker-compose/integration-keycloak\.yml$| diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index 8537f50dd79bc..05315110e404d 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -125,9 +125,16 @@ def add_default_pool_if_not_exists(session: Session = NEW_SESSION): @provide_session def create_default_connections(session: Session = NEW_SESSION): """Create default Airflow connections.""" + conns = get_default_connections() + + for c in conns: + merge_conn(c, session) + + +def get_default_connections(): from airflow.models.connection import Connection - merge_conn( + conns = [ Connection( conn_id="airflow_db", conn_type="mysql", @@ -136,40 +143,26 @@ def create_default_connections(session: Session = NEW_SESSION): password="", schema="airflow", ), - session, - ) - merge_conn( Connection( conn_id="athena_default", conn_type="athena", ), - session, - ) - merge_conn( Connection( conn_id="aws_default", conn_type="aws", ), - session, - ) - merge_conn( Connection( conn_id="azure_batch_default", conn_type="azure_batch", login="", password="", extra="""{"account_url": ""}""", - ) - ) - merge_conn( + ), Connection( conn_id="azure_cosmos_default", conn_type="azure_cosmos", extra='{"database_name": "", "collection_name": "" }', ), - session, - ) - merge_conn( Connection( conn_id="azure_data_explorer_default", conn_type="azure_data_explorer", @@ -178,50 +171,32 @@ def create_default_connections(session: Session = NEW_SESSION): "tenant": "", "certificate": "", "thumbprint": ""}""", ), - session, - ) - merge_conn( Connection( conn_id="azure_data_lake_default", conn_type="azure_data_lake", extra='{"tenant": "", "account_name": "" }', ), - session, - ) - merge_conn( Connection( conn_id="azure_default", conn_type="azure", ), - session, - ) - merge_conn( Connection( conn_id="cassandra_default", conn_type="cassandra", host="cassandra", port=9042, ), - session, - ) - merge_conn( Connection( conn_id="databricks_default", conn_type="databricks", host="localhost", ), - session, - ) - merge_conn( Connection( conn_id="dingding_default", conn_type="http", host="", password="", ), - session, - ) - merge_conn( Connection( conn_id="drill_default", conn_type="drill", @@ -229,9 +204,6 @@ def create_default_connections(session: Session = NEW_SESSION): port=8047, extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}', ), - session, - ) - merge_conn( Connection( conn_id="druid_broker_default", conn_type="druid", @@ -239,9 +211,6 @@ def create_default_connections(session: Session = NEW_SESSION): port=8082, extra='{"endpoint": "druid/v2/sql"}', ), - session, - ) - merge_conn( Connection( conn_id="druid_ingest_default", conn_type="druid", @@ -249,9 +218,6 @@ def create_default_connections(session: Session = NEW_SESSION): port=8081, extra='{"endpoint": "druid/indexer/v1/task"}', ), - session, - ) - merge_conn( Connection( conn_id="elasticsearch_default", conn_type="elasticsearch", @@ -259,9 +225,6 @@ def create_default_connections(session: Session = NEW_SESSION): schema="http", 
port=9200, ), - session, - ) - merge_conn( Connection( conn_id="emr_default", conn_type="emr", @@ -310,9 +273,6 @@ def create_default_connections(session: Session = NEW_SESSION): } """, ), - session, - ) - merge_conn( Connection( conn_id="facebook_default", conn_type="facebook_social", @@ -324,17 +284,11 @@ def create_default_connections(session: Session = NEW_SESSION): } """, ), - session, - ) - merge_conn( Connection( conn_id="fs_default", conn_type="fs", extra='{"path": "/"}', ), - session, - ) - merge_conn( Connection( conn_id="ftp_default", conn_type="ftp", @@ -344,26 +298,17 @@ def create_default_connections(session: Session = NEW_SESSION): password="airflow", extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', ), - session, - ) - merge_conn( Connection( conn_id="google_cloud_default", conn_type="google_cloud_platform", schema="default", ), - session, - ) - merge_conn( Connection( conn_id="gremlin_default", conn_type="gremlin", host="gremlin", port=8182, ), - session, - ) - merge_conn( Connection( conn_id="hive_cli_default", conn_type="hive_cli", @@ -372,9 +317,6 @@ def create_default_connections(session: Session = NEW_SESSION): extra='{"use_beeline": true, "auth": ""}', schema="default", ), - session, - ) - merge_conn( Connection( conn_id="hiveserver2_default", conn_type="hiveserver2", @@ -382,41 +324,26 @@ def create_default_connections(session: Session = NEW_SESSION): schema="default", port=10000, ), - session, - ) - merge_conn( Connection( conn_id="http_default", conn_type="http", host="https://www.httpbin.org/", ), - session, - ) - merge_conn( Connection( conn_id="iceberg_default", conn_type="iceberg", host="https://api.iceberg.io/ws/v1", ), - session, - ) - merge_conn(Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050)) - merge_conn( + Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050), Connection( conn_id="kafka_default", conn_type="kafka", extra=json.dumps({"bootstrap.servers": "broker:29092", "group.id": "my-group"}), ), - session, - ) - merge_conn( Connection( conn_id="kubernetes_default", conn_type="kubernetes", ), - session, - ) - merge_conn( Connection( conn_id="kylin_default", conn_type="kylin", @@ -425,18 +352,12 @@ def create_default_connections(session: Session = NEW_SESSION): login="ADMIN", password="KYLIN", ), - session, - ) - merge_conn( Connection( conn_id="leveldb_default", conn_type="leveldb", host="localhost", ), - session, - ) - merge_conn(Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), session) - merge_conn( + Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), Connection( conn_id="local_mysql", conn_type="mysql", @@ -445,9 +366,6 @@ def create_default_connections(session: Session = NEW_SESSION): password="airflow", schema="airflow", ), - session, - ) - merge_conn( Connection( conn_id="metastore_default", conn_type="hive_metastore", @@ -455,19 +373,13 @@ def create_default_connections(session: Session = NEW_SESSION): extra='{"authMechanism": "PLAIN"}', port=9083, ), - session, - ) - merge_conn(Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), session) - merge_conn( + Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), Connection( conn_id="mssql_default", conn_type="mssql", host="localhost", port=1433, ), - session, - ) - merge_conn( Connection( conn_id="mysql_default", conn_type="mysql", @@ -475,9 +387,6 @@ def create_default_connections(session: Session = 
NEW_SESSION): schema="airflow", host="mysql", ), - session, - ) - merge_conn( Connection( conn_id="opensearch_default", conn_type="opensearch", @@ -485,18 +394,12 @@ def create_default_connections(session: Session = NEW_SESSION): schema="http", port=9200, ), - session, - ) - merge_conn( Connection( conn_id="opsgenie_default", conn_type="http", host="", password="", ), - session, - ) - merge_conn( Connection( conn_id="oracle_default", conn_type="oracle", @@ -506,39 +409,28 @@ def create_default_connections(session: Session = NEW_SESSION): schema="schema", port=1521, ), - session, - ) - merge_conn( Connection( conn_id="oss_default", conn_type="oss", - extra="""{ + extra=""" + { "auth_type": "AK", "access_key_id": "", "access_key_secret": "", "region": ""} """, ), - session, - ) - merge_conn( Connection( conn_id="pig_cli_default", conn_type="pig_cli", schema="default", ), - session, - ) - merge_conn( Connection( conn_id="pinot_admin_default", conn_type="pinot", host="localhost", port=9000, ), - session, - ) - merge_conn( Connection( conn_id="pinot_broker_default", conn_type="pinot", @@ -546,9 +438,6 @@ def create_default_connections(session: Session = NEW_SESSION): port=9000, extra='{"endpoint": "/query", "schema": "http"}', ), - session, - ) - merge_conn( Connection( conn_id="postgres_default", conn_type="postgres", @@ -557,9 +446,6 @@ def create_default_connections(session: Session = NEW_SESSION): schema="airflow", host="postgres", ), - session, - ) - merge_conn( Connection( conn_id="presto_default", conn_type="presto", @@ -567,18 +453,12 @@ def create_default_connections(session: Session = NEW_SESSION): schema="hive", port=3400, ), - session, - ) - merge_conn( Connection( conn_id="qdrant_default", conn_type="qdrant", host="qdrant", port=6333, ), - session, - ) - merge_conn( Connection( conn_id="redis_default", conn_type="redis", @@ -586,13 +466,11 @@ def create_default_connections(session: Session = NEW_SESSION): port=6379, extra='{"db": 0}', ), - session, - ) - merge_conn( Connection( conn_id="redshift_default", conn_type="redshift", - extra="""{ + extra=""" +{ "iam": true, "cluster_identifier": "", "port": 5439, @@ -602,9 +480,6 @@ def create_default_connections(session: Session = NEW_SESSION): "region": "" }""", ), - session, - ) - merge_conn( Connection( conn_id="salesforce_default", conn_type="salesforce", @@ -612,17 +487,11 @@ def create_default_connections(session: Session = NEW_SESSION): password="password", extra='{"security_token": "security_token"}', ), - session, - ) - merge_conn( Connection( conn_id="segment_default", conn_type="segment", extra='{"write_key": "my-segment-write-key"}', ), - session, - ) - merge_conn( Connection( conn_id="sftp_default", conn_type="sftp", @@ -631,34 +500,22 @@ def create_default_connections(session: Session = NEW_SESSION): login="airflow", extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', ), - session, - ) - merge_conn( Connection( conn_id="spark_default", conn_type="spark", host="yarn", extra='{"queue": "root.default"}', ), - session, - ) - merge_conn( Connection( conn_id="sqlite_default", conn_type="sqlite", host=os.path.join(gettempdir(), "sqlite_default.db"), ), - session, - ) - merge_conn( Connection( conn_id="ssh_default", conn_type="ssh", host="localhost", ), - session, - ) - merge_conn( Connection( conn_id="tableau_default", conn_type="tableau", @@ -667,9 +524,6 @@ def create_default_connections(session: Session = NEW_SESSION): password="password", extra='{"site_id": "my_site"}', ), - session, - ) - merge_conn( Connection( 
conn_id="teradata_default", conn_type="teradata", @@ -678,9 +532,6 @@ def create_default_connections(session: Session = NEW_SESSION): password="password", schema="schema", ), - session, - ) - merge_conn( Connection( conn_id="trino_default", conn_type="trino", @@ -688,43 +539,28 @@ def create_default_connections(session: Session = NEW_SESSION): schema="hive", port=3400, ), - session, - ) - merge_conn( Connection( conn_id="vertica_default", conn_type="vertica", host="localhost", port=5433, ), - session, - ) - merge_conn( Connection( conn_id="wasb_default", conn_type="wasb", extra='{"sas_token": null}', ), - session, - ) - merge_conn( Connection( conn_id="webhdfs_default", conn_type="hdfs", host="localhost", port=50070, ), - session, - ) - merge_conn( Connection( conn_id="yandexcloud_default", conn_type="yandexcloud", schema="default", ), - session, - ) - merge_conn( Connection( conn_id="ydb_default", conn_type="ydb", @@ -732,8 +568,8 @@ def create_default_connections(session: Session = NEW_SESSION): port=2135, extra={"database": "/local"}, ), - session, - ) + ] + return conns def _create_db_from_orm(session): diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py index a4184e29e8da3..6c56232a68a92 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py @@ -27,7 +27,7 @@ from airflow.utils.session import provide_session from tests_common.test_utils.api_fastapi import _check_last_log -from tests_common.test_utils.db import clear_db_connections, clear_db_logs +from tests_common.test_utils.db import clear_db_connections, clear_db_logs, clear_test_connections from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker pytestmark = pytest.mark.db_test @@ -84,6 +84,7 @@ def _create_connections(session) -> None: class TestConnectionEndpoint: @pytest.fixture(autouse=True) def setup(self) -> None: + clear_test_connections(False) clear_db_connections(False) clear_db_logs() diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index c5711ec514c05..30e5b3c596331 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -17,6 +17,8 @@ # under the License. 
from __future__ import annotations +import json +from tempfile import gettempdir from typing import TYPE_CHECKING from airflow.configuration import conf @@ -40,7 +42,11 @@ from airflow.models.dagwarning import DagWarning from airflow.models.serialized_dag import SerializedDagModel from airflow.security.permissions import RESOURCE_DAG_PREFIX -from airflow.utils.db import add_default_pool_if_not_exists, create_default_connections, reflect_tables +from airflow.utils.db import ( + add_default_pool_if_not_exists, + create_default_connections, + reflect_tables, +) from airflow.utils.session import create_session from tests_common.test_utils.compat import ( @@ -223,6 +229,18 @@ def clear_db_pools(): add_default_pool_if_not_exists(session) +def clear_test_connections(add_default_connections_back=True): + # clear environment variables with AIRFLOW_CONN prefix + import os + + env_vars_to_remove = [key for key in os.environ.keys() if key.startswith("AIRFLOW_CONN_")] + for env_var in env_vars_to_remove: + del os.environ[env_var] + + if add_default_connections_back: + create_default_connections_for_tests() + + def clear_db_connections(add_default_connections_back=True): with create_session() as session: session.query(Connection).delete() @@ -338,6 +356,461 @@ def clear_dag_specific_permissions(): session.query(Resource).filter(Resource.id.in_(dag_resource_ids)).delete(synchronize_session=False) +def create_default_connections_for_tests(): + """ + Create default Airflow connections for tests. + + For testing purposes, we do not need to have the connections setup in the database, using environment + variables instead would provide better lookup speeds and is easier too. + """ + import os + + try: + from airflow.utils.db import get_default_connections + + conns = get_default_connections() + except ImportError: + conns = [ + Connection( + conn_id="airflow_db", + conn_type="mysql", + host="mysql", + login="root", + password="", + schema="airflow", + ), + Connection( + conn_id="athena_default", + conn_type="athena", + ), + Connection( + conn_id="aws_default", + conn_type="aws", + ), + Connection( + conn_id="azure_batch_default", + conn_type="azure_batch", + login="", + password="", + extra="""{"account_url": ""}""", + ), + Connection( + conn_id="azure_cosmos_default", + conn_type="azure_cosmos", + extra='{"database_name": "", "collection_name": "" }', + ), + Connection( + conn_id="azure_data_explorer_default", + conn_type="azure_data_explorer", + host="https://.kusto.windows.net", + extra="""{"auth_method": "", + "tenant": "", "certificate": "", + "thumbprint": ""}""", + ), + Connection( + conn_id="azure_data_lake_default", + conn_type="azure_data_lake", + extra='{"tenant": "", "account_name": "" }', + ), + Connection( + conn_id="azure_default", + conn_type="azure", + ), + Connection( + conn_id="cassandra_default", + conn_type="cassandra", + host="cassandra", + port=9042, + ), + Connection( + conn_id="databricks_default", + conn_type="databricks", + host="localhost", + ), + Connection( + conn_id="dingding_default", + conn_type="http", + host="", + password="", + ), + Connection( + conn_id="drill_default", + conn_type="drill", + host="localhost", + port=8047, + extra='{"dialect_driver": "drill+sadrill", "storage_plugin": "dfs"}', + ), + Connection( + conn_id="druid_broker_default", + conn_type="druid", + host="druid-broker", + port=8082, + extra='{"endpoint": "druid/v2/sql"}', + ), + Connection( + conn_id="druid_ingest_default", + conn_type="druid", + host="druid-overlord", + port=8081, + extra='{"endpoint": 
"druid/indexer/v1/task"}', + ), + Connection( + conn_id="elasticsearch_default", + conn_type="elasticsearch", + host="localhost", + schema="http", + port=9200, + ), + Connection( + conn_id="emr_default", + conn_type="emr", + extra=""" + { "Name": "default_job_flow_name", + "LogUri": "s3://my-emr-log-bucket/default_job_flow_location", + "ReleaseLabel": "emr-4.6.0", + "Instances": { + "Ec2KeyName": "mykey", + "Ec2SubnetId": "somesubnet", + "InstanceGroups": [ + { + "Name": "Master nodes", + "Market": "ON_DEMAND", + "InstanceRole": "MASTER", + "InstanceType": "r3.2xlarge", + "InstanceCount": 1 + }, + { + "Name": "Core nodes", + "Market": "ON_DEMAND", + "InstanceRole": "CORE", + "InstanceType": "r3.2xlarge", + "InstanceCount": 1 + } + ], + "TerminationProtected": false, + "KeepJobFlowAliveWhenNoSteps": false + }, + "Applications":[ + { "Name": "Spark" } + ], + "VisibleToAllUsers": true, + "JobFlowRole": "EMR_EC2_DefaultRole", + "ServiceRole": "EMR_DefaultRole", + "Tags": [ + { + "Key": "app", + "Value": "analytics" + }, + { + "Key": "environment", + "Value": "development" + } + ] + } + """, + ), + Connection( + conn_id="facebook_default", + conn_type="facebook_social", + extra=""" + { "account_id": "", + "app_id": "", + "app_secret": "", + "access_token": "" + } + """, + ), + Connection( + conn_id="fs_default", + conn_type="fs", + extra='{"path": "/"}', + ), + Connection( + conn_id="ftp_default", + conn_type="ftp", + host="localhost", + port=21, + login="airflow", + password="airflow", + extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', + ), + Connection( + conn_id="google_cloud_default", + conn_type="google_cloud_platform", + schema="default", + ), + Connection( + conn_id="gremlin_default", + conn_type="gremlin", + host="gremlin", + port=8182, + ), + Connection( + conn_id="hive_cli_default", + conn_type="hive_cli", + port=10000, + host="localhost", + extra='{"use_beeline": true, "auth": ""}', + schema="default", + ), + Connection( + conn_id="hiveserver2_default", + conn_type="hiveserver2", + host="localhost", + schema="default", + port=10000, + ), + Connection( + conn_id="http_default", + conn_type="http", + host="https://www.httpbin.org/", + ), + Connection( + conn_id="iceberg_default", + conn_type="iceberg", + host="https://api.iceberg.io/ws/v1", + ), + Connection(conn_id="impala_default", conn_type="impala", host="localhost", port=21050), + Connection( + conn_id="kafka_default", + conn_type="kafka", + extra=json.dumps({"bootstrap.servers": "broker:29092", "group.id": "my-group"}), + ), + Connection( + conn_id="kubernetes_default", + conn_type="kubernetes", + ), + Connection( + conn_id="kylin_default", + conn_type="kylin", + host="localhost", + port=7070, + login="ADMIN", + password="KYLIN", + ), + Connection( + conn_id="leveldb_default", + conn_type="leveldb", + host="localhost", + ), + Connection(conn_id="livy_default", conn_type="livy", host="livy", port=8998), + Connection( + conn_id="local_mysql", + conn_type="mysql", + host="localhost", + login="airflow", + password="airflow", + schema="airflow", + ), + Connection( + conn_id="metastore_default", + conn_type="hive_metastore", + host="localhost", + extra='{"authMechanism": "PLAIN"}', + port=9083, + ), + Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017), + Connection( + conn_id="mssql_default", + conn_type="mssql", + host="localhost", + port=1433, + ), + Connection( + conn_id="mysql_default", + conn_type="mysql", + login="root", + schema="airflow", + host="mysql", + ), + Connection( + 
conn_id="opensearch_default", + conn_type="opensearch", + host="localhost", + schema="http", + port=9200, + ), + Connection( + conn_id="opsgenie_default", + conn_type="http", + host="", + password="", + ), + Connection( + conn_id="oracle_default", + conn_type="oracle", + host="localhost", + login="root", + password="password", + schema="schema", + port=1521, + ), + Connection( + conn_id="oss_default", + conn_type="oss", + extra=""" + { + "auth_type": "AK", + "access_key_id": "", + "access_key_secret": "", + "region": ""} + """, + ), + Connection( + conn_id="pig_cli_default", + conn_type="pig_cli", + schema="default", + ), + Connection( + conn_id="pinot_admin_default", + conn_type="pinot", + host="localhost", + port=9000, + ), + Connection( + conn_id="pinot_broker_default", + conn_type="pinot", + host="localhost", + port=9000, + extra='{"endpoint": "/query", "schema": "http"}', + ), + Connection( + conn_id="postgres_default", + conn_type="postgres", + login="postgres", + password="airflow", + schema="airflow", + host="postgres", + ), + Connection( + conn_id="presto_default", + conn_type="presto", + host="localhost", + schema="hive", + port=3400, + ), + Connection( + conn_id="qdrant_default", + conn_type="qdrant", + host="qdrant", + port=6333, + ), + Connection( + conn_id="redis_default", + conn_type="redis", + host="redis", + port=6379, + extra='{"db": 0}', + ), + Connection( + conn_id="redshift_default", + conn_type="redshift", + extra=""" +{ + "iam": true, + "cluster_identifier": "", + "port": 5439, + "profile": "default", + "db_user": "awsuser", + "database": "dev", + "region": "" +}""", + ), + Connection( + conn_id="salesforce_default", + conn_type="salesforce", + login="username", + password="password", + extra='{"security_token": "security_token"}', + ), + Connection( + conn_id="segment_default", + conn_type="segment", + extra='{"write_key": "my-segment-write-key"}', + ), + Connection( + conn_id="sftp_default", + conn_type="sftp", + host="localhost", + port=22, + login="airflow", + extra='{"key_file": "~/.ssh/id_rsa", "no_host_key_check": true}', + ), + Connection( + conn_id="spark_default", + conn_type="spark", + host="yarn", + extra='{"queue": "root.default"}', + ), + Connection( + conn_id="sqlite_default", + conn_type="sqlite", + host=os.path.join(gettempdir(), "sqlite_default.db"), + ), + Connection( + conn_id="ssh_default", + conn_type="ssh", + host="localhost", + ), + Connection( + conn_id="tableau_default", + conn_type="tableau", + host="https://tableau.server.url", + login="user", + password="password", + extra='{"site_id": "my_site"}', + ), + Connection( + conn_id="teradata_default", + conn_type="teradata", + host="localhost", + login="user", + password="password", + schema="schema", + ), + Connection( + conn_id="trino_default", + conn_type="trino", + host="localhost", + schema="hive", + port=3400, + ), + Connection( + conn_id="vertica_default", + conn_type="vertica", + host="localhost", + port=5433, + ), + Connection( + conn_id="wasb_default", + conn_type="wasb", + extra='{"sas_token": null}', + ), + Connection( + conn_id="webhdfs_default", + conn_type="hdfs", + host="localhost", + port=50070, + ), + Connection( + conn_id="yandexcloud_default", + conn_type="yandexcloud", + schema="default", + ), + Connection( + conn_id="ydb_default", + conn_type="ydb", + host="grpc://localhost", + port=2135, + extra={"database": "/local"}, + ), + ] + + for c in conns: + envvar = f"AIRFLOW_CONN_{c.conn_id.upper()}" + os.environ[envvar] = c.as_json() + + def clear_all(): clear_db_runs() 
clear_db_assets() @@ -355,7 +828,7 @@ def clear_all(): clear_db_xcom() clear_db_variables() clear_db_pools() - clear_db_connections(add_default_connections_back=True) + clear_test_connections(add_default_connections_back=True) clear_db_deadline() clear_dag_specific_permissions() if AIRFLOW_V_3_0_PLUS: diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py index ab60c70f4e838..c794c42431eae 100644 --- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py +++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py @@ -30,7 +30,7 @@ from airflow.models import Connection from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook, LivyHook -from tests_common.test_utils.db import clear_db_connections +from tests_common.test_utils.db import clear_test_connections LIVY_CONN_ID = LivyHook.default_conn_name DEFAULT_CONN_ID = LivyHook.default_conn_name @@ -56,11 +56,11 @@ class TestLivyDbHook: @classmethod def setup_class(cls): - clear_db_connections(add_default_connections_back=False) + clear_test_connections(add_default_connections_back=False) @classmethod def teardown_class(cls): - clear_db_connections(add_default_connections_back=True) + clear_test_connections(add_default_connections_back=True) # TODO: Potential performance issue, converted setup_class to a setup_connections function level fixture @pytest.fixture(autouse=True) diff --git a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_sql.py b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_sql.py index 301ccd730fa81..deeae5cb5e22d 100644 --- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_sql.py +++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_sql.py @@ -27,7 +27,7 @@ from airflow.models import Connection from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook -from tests_common.test_utils.db import clear_db_connections +from tests_common.test_utils.db import clear_test_connections def get_after(sentinel, iterable): @@ -64,7 +64,7 @@ class TestSparkSqlHook: @classmethod def setup_class(cls) -> None: - clear_db_connections(add_default_connections_back=False) + clear_test_connections(add_default_connections_back=False) @pytest.fixture(autouse=True) def setup_connections(self, create_connection_without_db): @@ -74,7 +74,7 @@ def setup_connections(self, create_connection_without_db): @classmethod def teardown_class(cls) -> None: - clear_db_connections(add_default_connections_back=True) + clear_test_connections(add_default_connections_back=True) @pytest.mark.db_test def test_build_command(self): diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py index ad1b07ebf4811..1b8db52452299 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/hooks/test_kubernetes.py @@ -31,14 +31,12 @@ from kubernetes.client import V1Deployment, V1DeploymentStatus from kubernetes.client.rest import ApiException from kubernetes.config import ConfigException -from sqlalchemy.orm import make_transient from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.hooks.base import BaseHook from airflow.models import Connection from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook, KubernetesHook 
-from tests_common.test_utils.db import clear_db_connections +from tests_common.test_utils.db import clear_test_connections from tests_common.test_utils.providers import get_provider_min_airflow_version pytestmark = pytest.mark.db_test @@ -79,16 +77,18 @@ class DeprecationRemovalRequired(AirflowException): ... @pytest.fixture -def remove_default_conn(session): - before_conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).one_or_none() - if before_conn: - session.delete(before_conn) - session.commit() +def remove_default_conn(monkeypatch): + original_env_var = os.environ.get(f"AIRFLOW_CONN_{DEFAULT_CONN_ID.upper()}") + + # remove the env variable to simulate no default connection + if original_env_var: + monkeypatch.delenv(f"AIRFLOW_CONN_{DEFAULT_CONN_ID.upper()}") + yield - if before_conn: - make_transient(before_conn) - session.add(before_conn) - session.commit() + + # restore the original env variable + if original_env_var: + monkeypatch.setenv(f"AIRFLOW_CONN_{DEFAULT_CONN_ID.upper()}", original_env_var) class TestKubernetesHook: @@ -138,7 +138,7 @@ def setup_connections(self, create_connection_without_db): @classmethod def teardown_class(cls) -> None: - clear_db_connections() + clear_test_connections() @pytest.mark.parametrize( "in_cluster_param, conn_id, in_cluster_called", @@ -439,8 +439,8 @@ def test_prefixed_names_still_work(self, mock_get_client): def test_missing_default_connection_is_ok(self, remove_default_conn): # prove to ourselves that the default conn doesn't exist - with pytest.raises(AirflowNotFoundException): - BaseHook.get_connection(DEFAULT_CONN_ID) + k8s_conn_exists = os.environ.get(f"AIRFLOW_CONN_{DEFAULT_CONN_ID.upper()}") + assert k8s_conn_exists is None # verify K8sHook still works hook = KubernetesHook() @@ -849,7 +849,7 @@ def kubernetes_connection(create_connection_without_db): ), ) yield - clear_db_connections() + clear_test_connections() @pytest.mark.asyncio @mock.patch(INCLUSTER_CONFIG_LOADER) @@ -936,7 +936,7 @@ async def test_load_config_with_conn_id_kube_config_path( except: raise finally: - clear_db_connections() + clear_test_connections() @pytest.mark.asyncio @mock.patch(INCLUSTER_CONFIG_LOADER) diff --git a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py index e88c6fbeb3d14..c3361cf2fba56 100644 --- a/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py +++ b/providers/google/tests/unit/google/cloud/log/test_gcs_task_handler_system.py @@ -31,7 +31,7 @@ from airflow.utils.session import provide_session from tests_common.test_utils.config import conf_vars -from tests_common.test_utils.db import clear_db_connections, clear_db_runs +from tests_common.test_utils.db import clear_db_runs, clear_test_connections from tests_common.test_utils.gcp_system_helpers import ( GoogleSystemTest, provide_gcp_context, @@ -48,7 +48,7 @@ def setup_class(cls) -> None: unique_suffix = "".join(random.sample(string.ascii_lowercase, 16)) cls.bucket_name = f"airflow-gcs-task-handler-tests-{unique_suffix}" # type: ignore cls.create_gcs_bucket(cls.bucket_name) # type: ignore - clear_db_connections() + clear_test_connections() @classmethod def teardown_class(cls) -> None: diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataprep_system.py b/providers/google/tests/unit/google/cloud/operators/test_dataprep_system.py index f27aa92684fff..eca2a1493f5cd 100644 --- 
a/providers/google/tests/unit/google/cloud/operators/test_dataprep_system.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataprep_system.py @@ -25,7 +25,7 @@ from airflow.models import Connection from airflow.utils.session import create_session -from tests_common.test_utils.db import clear_db_connections +from tests_common.test_utils.db import clear_test_connections from tests_common.test_utils.gcp_system_helpers import GoogleSystemTest from tests_common.test_utils.system_tests import get_test_run @@ -50,7 +50,7 @@ def setup_method(self): session.add(dataprep_conn_id) def teardown_method(self): - clear_db_connections() + clear_test_connections() def test_run_example_dag(self): from unit.google.cloud.dataprep.example_dataprep import dag diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py index 43f8208dbc640..bfd83b81e97c0 100644 --- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py +++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py @@ -34,9 +34,6 @@ from airflow.exceptions import AirflowException from airflow.models import Connection from airflow.providers.sftp.hooks.sftp import SFTPHook, SFTPHookAsync -from airflow.utils.session import provide_session - -pytestmark = pytest.mark.db_test def generate_host_key(pkey: paramiko.PKey): @@ -62,13 +59,31 @@ def generate_host_key(pkey: paramiko.PKey): class TestSFTPHook: - @provide_session - def update_connection(self, login, session=None): - connection = session.query(Connection).filter(Connection.conn_id == "sftp_default").first() - old_login = connection.login - connection.login = login - connection.extra = "" # clear out extra so it doesn't look for a key file - session.commit() + def update_connection(self, login): + import os + + # Get the current connection from environment variable to find the old login + old_connection = os.environ.get("AIRFLOW_CONN_SFTP_DEFAULT") + old_login = "airflow" # default fallback + + if old_connection: + try: + old_conn = Connection.from_json(old_connection) + old_login = old_conn.login + except Exception: + pass + + # Set the connection as an environment variable + new_connection = Connection( + conn_id="sftp_default", + conn_type="sftp", + host="localhost", + login=login, + password="airflow", + extra="", # clear out extra so it doesn't look for a key file + ) + os.environ[f"AIRFLOW_CONN_{new_connection.conn_id.upper()}"] = new_connection.as_json() + return old_login def _create_additional_test_file(self, file_name): @@ -549,7 +564,17 @@ def test_store_and_retrieve_directory_concurrently(self): @patch("paramiko.SSHClient") @patch("paramiko.ProxyCommand") - def test_sftp_hook_with_proxy_command(self, mock_proxy_command, mock_ssh_client): + @patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection") + def test_sftp_hook_with_proxy_command(self, mock_get_connection, mock_proxy_command, mock_ssh_client): + # Mock the connection to not have a password + mock_connection = MagicMock() + mock_connection.login = "user" + mock_connection.password = None + mock_connection.host = "example.com" + mock_connection.port = 22 + mock_connection.extra = None + mock_get_connection.return_value = mock_connection + mock_sftp_client = MagicMock(spec=SFTPClient) mock_ssh_client.open_sftp.return_value = mock_sftp_client
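Two test-side patterns from the provider changes above are worth spelling out. First, the reworked remove_default_conn fixture in the Kubernetes hook tests hides a connection by unsetting its AIRFLOW_CONN_* variable rather than deleting a row; a slightly simplified sketch that relies on monkeypatch's automatic restore on teardown (the connection id is illustrative):

import os

import pytest


@pytest.fixture
def remove_example_conn(monkeypatch):
    # Unset the env-var-backed connection for the duration of the test;
    # monkeypatch puts the original value back automatically on teardown.
    monkeypatch.delenv("AIRFLOW_CONN_EXAMPLE_DEFAULT", raising=False)
    yield


def test_connection_is_absent(remove_example_conn):
    assert os.environ.get("AIRFLOW_CONN_EXAMPLE_DEFAULT") is None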
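Second, test_sftp_hook_with_proxy_command above drops the database dependency by patching SFTPHook.get_connection and handing back a stand-in object. A reduced sketch of that mocking pattern; the host, login, and test body are illustrative, not part of the patch:

from unittest.mock import MagicMock, patch

from airflow.providers.sftp.hooks.sftp import SFTPHook


@patch("airflow.providers.sftp.hooks.sftp.SFTPHook.get_connection")
def test_hook_without_backing_connection(mock_get_connection):
    # Stand-in connection: no password and an empty extra, so the hook does not
    # parse extra options or go looking for a key file.
    mock_connection = MagicMock()
    mock_connection.login = "user"
    mock_connection.password = None
    mock_connection.host = "example.com"
    mock_connection.port = 22
    mock_connection.extra = None
    mock_get_connection.return_value = mock_connection

    hook = SFTPHook(ssh_conn_id="sftp_default")
    # The hook resolved the patched connection instead of the DB or env vars.
    assert hook.get_connection("sftp_default") is mock_connection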