From 4ae85d754e9f8a65d461e86eb6111d3b9974a065 Mon Sep 17 00:00:00 2001
From: max <42827971+moiseenkov@users.noreply.github.com>
Date: Tue, 23 Apr 2024 10:47:42 +0000
Subject: [PATCH] Bugfix BigQueryToMsSqlOperator (#39171)

---
 .../cloud/transfers/bigquery_to_mssql.py      |   2 +-
 tests/always/test_project_structure.py        |   1 -
 .../cloud/transfers/test_bigquery_to_mssql.py |  88 ++++++
 .../bigquery/example_bigquery_to_mssql.py     | 276 ++++++++++++++++--
 4 files changed, 342 insertions(+), 25 deletions(-)
 create mode 100644 tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py

diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
index c251ec5615aa1..8a5749dc9e2ce 100644
--- a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
+++ b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -91,7 +91,7 @@ def __init__(
         self.source_project_dataset_table = source_project_dataset_table
 
     def get_sql_hook(self) -> MsSqlHook:
-        return MsSqlHook(schema=self.database, mysql_conn_id=self.mssql_conn_id)
+        return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)
 
     def persist_links(self, context: Context) -> None:
         project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 4341dd1aec9f2..3437092e6569f 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -143,7 +143,6 @@ def test_providers_modules_should_have_tests(self):
         "tests/providers/google/cloud/operators/vertex_ai/test_model_service.py",
         "tests/providers/google/cloud/operators/vertex_ai/test_pipeline_job.py",
         "tests/providers/google/cloud/sensors/test_dataform.py",
-        "tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py",
         "tests/providers/google/cloud/transfers/test_bigquery_to_sql.py",
         "tests/providers/google/cloud/transfers/test_mssql_to_gcs.py",
         "tests/providers/google/cloud/transfers/test_presto_to_gcs.py",
diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py
new file mode 100644
index 0000000000000..e4fd89732467c
--- /dev/null
+++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -0,0 +1,88 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
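+# Unit tests for BigQueryToMsSqlOperator. test_get_sql_hook pins down the fix in
+# this patch: get_sql_hook() must build MsSqlHook with mssql_conn_id; it previously
+# passed the connection id as mysql_conn_id, a keyword MsSqlHook does not recognize,
+# so the hook effectively fell back to the default connection.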
+from __future__ import annotations
+
+from unittest import mock
+
+from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator
+
+TASK_ID = "test-bq-create-table-operator"
+TEST_DATASET = "test-dataset"
+TEST_TABLE_ID = "test-table-id"
+TEST_DAG_ID = "test-bigquery-operators"
+TEST_PROJECT = "test-project"
+
+
+class TestBigQueryToMsSqlOperator:
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
+    def test_execute_good_request_to_bq(self, mock_hook, mock_link):
+        destination_table = "table"
+        operator = BigQueryToMsSqlOperator(
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name=destination_table,
+            replace=False,
+        )
+
+        operator.execute(None)
+        mock_hook.return_value.list_rows.assert_called_once_with(
+            dataset_id=TEST_DATASET,
+            table_id=TEST_TABLE_ID,
+            max_results=1000,
+            selected_fields=None,
+            start_index=0,
+        )
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
+    def test_get_sql_hook(self, mock_hook):
+        hook_expected = mock_hook.return_value
+
+        destination_table = "table"
+        operator = BigQueryToMsSqlOperator(
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name=destination_table,
+            replace=False,
+        )
+
+        hook_actual = operator.get_sql_hook()
+
+        assert hook_actual == hook_expected
+        mock_hook.assert_called_once_with(schema=operator.database, mssql_conn_id=operator.mssql_conn_id)
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
+    def test_persist_links(self, mock_link):
+        mock_context = mock.MagicMock()
+
+        destination_table = "table"
+        operator = BigQueryToMsSqlOperator(
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{TEST_PROJECT}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            target_table_name=destination_table,
+            replace=False,
+        )
+        operator.persist_links(context=mock_context)
+
+        mock_link.persist.assert_called_once_with(
+            context=mock_context,
+            task_instance=operator,
+            dataset_id=TEST_DATASET,
+            project_id=TEST_PROJECT,
+            table_id=TEST_TABLE_ID,
+        )
diff --git a/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py b/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
index 822020df28b66..2ad7671ec3479 100644
--- a/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
+++ b/tests/system/providers/google/cloud/bigquery/example_bigquery_to_mssql.py
@@ -17,21 +17,42 @@
 # under the License.
 """
 Example Airflow DAG for Google BigQuery service.
+
+This DAG relies on the following OS environment variables:
+
+* AIRFLOW__API__GOOGLE_KEY_PATH - Path to service account key file. Note: you can skip this variable if you
+  run this DAG in a Composer environment.
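+* SYSTEM_TESTS_ENV_ID - Unique suffix appended to the names of the resources created by this test.
+* SYSTEM_TESTS_GCP_PROJECT - Google Cloud project id in which the test resources are created.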
""" from __future__ import annotations +import logging import os from datetime import datetime import pytest +from pendulum import duration +from airflow.decorators import task +from airflow.models import Connection from airflow.models.dag import DAG +from airflow.operators.bash import BashOperator +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator +from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook +from airflow.providers.google.cloud.hooks.compute_ssh import ComputeEngineSSHHook from airflow.providers.google.cloud.operators.bigquery import ( BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryDeleteDatasetOperator, + BigQueryInsertJobOperator, ) +from airflow.providers.google.cloud.operators.compute import ( + ComputeEngineDeleteInstanceOperator, + ComputeEngineInsertInstanceOperator, +) +from airflow.providers.ssh.operators.ssh import SSHOperator +from airflow.settings import Session +from airflow.utils.trigger_rule import TriggerRule try: from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator @@ -39,13 +60,102 @@ pytest.skip("MsSQL not available", allow_module_level=True) ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") -PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "example-project") DAG_ID = "example_bigquery_to_mssql" -DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" -DATA_EXPORT_BUCKET_NAME = os.environ.get("GCP_BIGQUERY_EXPORT_BUCKET_NAME", "INVALID BUCKET NAME") -TABLE = "table_42" -destination_table = "mssql_table_test" + +REGION = "europe-west2" +ZONE = REGION + "-a" +NETWORK = "default" +CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}".replace("-", "_") +CONNECTION_TYPE = "mssql" + +BIGQUERY_DATASET_NAME = f"dataset_{DAG_ID}_{ENV_ID}" +BIGQUERY_TABLE = "table_42" +INSERT_ROWS_QUERY = ( + f"INSERT INTO {BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE} (emp_name, salary) " + "VALUES ('emp 1', 10000), ('emp 2', 15000);" +) + +DB_PORT = 1433 +DB_USER_NAME = "sa" +DB_USER_PASSWORD = "5FHq4fSZ85kK6g0n" +SETUP_MSSQL_COMMAND = f""" +sudo apt update && +sudo apt install -y docker.io && +sudo docker run -e ACCEPT_EULA=Y -e MSSQL_SA_PASSWORD={DB_USER_PASSWORD} -p {DB_PORT}:{DB_PORT} \ + -d mcr.microsoft.com/mssql/server:2022-latest +""" +SQL_TABLE = "test_table" +SQL_CREATE_TABLE = f"""if not exists (select * from sys.tables where sys.tables.name='{SQL_TABLE}' and sys.tables.type='U') + create table {SQL_TABLE} ( + emp_name VARCHAR(8), + salary INT + ) +""" + +GCE_MACHINE_TYPE = "n1-standard-1" +GCE_INSTANCE_NAME = f"instance-{DAG_ID}-{ENV_ID}".replace("_", "-") +GCE_INSTANCE_BODY = { + "name": GCE_INSTANCE_NAME, + "machine_type": f"zones/{ZONE}/machineTypes/{GCE_MACHINE_TYPE}", + "disks": [ + { + "boot": True, + "device_name": GCE_INSTANCE_NAME, + "initialize_params": { + "disk_size_gb": "10", + "disk_type": f"zones/{ZONE}/diskTypes/pd-balanced", + "source_image": "projects/debian-cloud/global/images/debian-11-bullseye-v20220621", + }, + } + ], + "network_interfaces": [ + { + "access_configs": [{"name": "External NAT", "network_tier": "PREMIUM"}], + "stack_type": "IPV4_ONLY", + "subnetwork": f"regions/{REGION}/subnetworks/default", + } + ], +} +FIREWALL_RULE_NAME = f"allow-http-{DB_PORT}" +CREATE_FIREWALL_RULE_COMMAND = f""" +if [ $AIRFLOW__API__GOOGLE_KEY_PATH ]; then \ + gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \ +fi; + +if [ -z gcloud compute firewall-rules list 
+    gcloud compute firewall-rules create {FIREWALL_RULE_NAME} \
+        --project={PROJECT_ID} \
+        --direction=INGRESS \
+        --priority=100 \
+        --network={NETWORK} \
+        --action=ALLOW \
+        --rules=tcp:{DB_PORT} \
+        --source-ranges=0.0.0.0/0
+else
+    echo "Firewall rule {FIREWALL_RULE_NAME} already exists."
+fi
+"""
+DELETE_FIREWALL_RULE_COMMAND = f"""
+if [ "$AIRFLOW__API__GOOGLE_KEY_PATH" ]; then \
+    gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi; \
+if [ -n "$(gcloud compute firewall-rules list --filter=name:{FIREWALL_RULE_NAME} --format='value(name)')" ]; then \
+    gcloud compute firewall-rules delete {FIREWALL_RULE_NAME} --project={PROJECT_ID} --quiet; \
+fi;
+"""
+DELETE_PERSISTENT_DISK_COMMAND = f"""
+if [ "$AIRFLOW__API__GOOGLE_KEY_PATH" ]; then \
+    gcloud auth activate-service-account --key-file=$AIRFLOW__API__GOOGLE_KEY_PATH; \
+fi;
+
+gcloud compute disks delete {GCE_INSTANCE_NAME} --project={PROJECT_ID} --zone={ZONE} --quiet
+"""
+
+
+log = logging.getLogger(__name__)
+
 
 with DAG(
     DAG_ID,
@@ -54,41 +164,161 @@
     catchup=False,
     tags=["example", "bigquery"],
 ) as dag:
+    create_bigquery_dataset = BigQueryCreateEmptyDatasetOperator(
+        task_id="create_bigquery_dataset", dataset_id=BIGQUERY_DATASET_NAME
+    )
+
+    create_bigquery_table = BigQueryCreateEmptyTableOperator(
+        task_id="create_bigquery_table",
+        dataset_id=BIGQUERY_DATASET_NAME,
+        table_id=BIGQUERY_TABLE,
+        schema_fields=[
+            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
+            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
+        ],
+    )
+
+    insert_bigquery_data = BigQueryInsertJobOperator(
+        task_id="insert_bigquery_data",
+        configuration={
+            "query": {
+                "query": INSERT_ROWS_QUERY,
+                "useLegacySql": False,
+                "priority": "BATCH",
+            }
+        },
+    )
+
+    create_gce_instance = ComputeEngineInsertInstanceOperator(
+        task_id="create_gce_instance",
+        project_id=PROJECT_ID,
+        zone=ZONE,
+        body=GCE_INSTANCE_BODY,
+    )
+
+    create_firewall_rule = BashOperator(
+        task_id="create_firewall_rule",
+        bash_command=CREATE_FIREWALL_RULE_COMMAND,
+    )
+
+    setup_mssql = SSHOperator(
+        task_id="setup_mssql",
+        ssh_hook=ComputeEngineSSHHook(
+            user="username",
+            instance_name=GCE_INSTANCE_NAME,
+            zone=ZONE,
+            project_id=PROJECT_ID,
+            use_oslogin=False,
+            use_iap_tunnel=False,
+            cmd_timeout=180,
+        ),
+        command=SETUP_MSSQL_COMMAND,
+        retries=4,
+    )
+
+    @task
+    def get_public_ip() -> str:
+        hook = ComputeEngineHook()
+        address = hook.get_instance_address(resource_id=GCE_INSTANCE_NAME, zone=ZONE, project_id=PROJECT_ID)
+        return address
+
+    get_public_ip_task = get_public_ip()
+
+    @task
+    def setup_connection(ip_address: str) -> None:
+        connection = Connection(
+            conn_id=CONNECTION_ID,
+            description="Example connection",
+            conn_type=CONNECTION_TYPE,
+            host=ip_address,
+            login=DB_USER_NAME,
+            password=DB_USER_PASSWORD,
+            port=DB_PORT,
+        )
+        session = Session()
+        log.info("Removing connection %s if it exists", CONNECTION_ID)
+        query = session.query(Connection).filter(Connection.conn_id == CONNECTION_ID)
+        query.delete()
+
+        session.add(connection)
+        session.commit()
+        log.info("Connection %s created", CONNECTION_ID)
+
+    setup_connection_task = setup_connection(get_public_ip_task)
+
+    create_sql_table = SQLExecuteQueryOperator(
+        task_id="create_sql_table",
+        conn_id=CONNECTION_ID,
+        sql=SQL_CREATE_TABLE,
+        retries=4,
+        retry_delay=duration(seconds=20),
+        retry_exponential_backoff=False,
+    )
+
     # [START howto_operator_bigquery_to_mssql]
     bigquery_to_mssql = BigQueryToMsSqlOperator(
         task_id="bigquery_to_mssql",
-        source_project_dataset_table=f"{PROJECT_ID}.{DATASET_NAME}.{TABLE}",
-        target_table_name=destination_table,
+        mssql_conn_id=CONNECTION_ID,
+        source_project_dataset_table=f"{PROJECT_ID}.{BIGQUERY_DATASET_NAME}.{BIGQUERY_TABLE}",
+        target_table_name=SQL_TABLE,
         replace=False,
     )
     # [END howto_operator_bigquery_to_mssql]
 
-    create_dataset = BigQueryCreateEmptyDatasetOperator(task_id="create_dataset", dataset_id=DATASET_NAME)
+    delete_bigquery_dataset = BigQueryDeleteDatasetOperator(
+        task_id="delete_bigquery_dataset",
+        dataset_id=BIGQUERY_DATASET_NAME,
+        delete_contents=True,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
 
-    create_table = BigQueryCreateEmptyTableOperator(
-        task_id="create_table",
-        dataset_id=DATASET_NAME,
-        table_id=TABLE,
-        schema_fields=[
-            {"name": "emp_name", "type": "STRING", "mode": "REQUIRED"},
-            {"name": "salary", "type": "INTEGER", "mode": "NULLABLE"},
-        ],
+    delete_firewall_rule = BashOperator(
+        task_id="delete_firewall_rule",
+        bash_command=DELETE_FIREWALL_RULE_COMMAND,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
-    delete_dataset = BigQueryDeleteDatasetOperator(
-        task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
+    delete_gce_instance = ComputeEngineDeleteInstanceOperator(
+        task_id="delete_gce_instance",
+        resource_id=GCE_INSTANCE_NAME,
+        zone=ZONE,
+        project_id=PROJECT_ID,
+        trigger_rule=TriggerRule.ALL_DONE,
     )
 
+    delete_persistent_disk = BashOperator(
+        task_id="delete_persistent_disk",
+        bash_command=DELETE_PERSISTENT_DISK_COMMAND,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    delete_connection = BashOperator(
+        task_id="delete_connection",
+        bash_command=f"airflow connections delete {CONNECTION_ID}",
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    # TEST SETUP
+    create_bigquery_dataset >> create_bigquery_table >> insert_bigquery_data
+    create_gce_instance >> setup_mssql
+    create_gce_instance >> get_public_ip_task >> setup_connection_task
+    [setup_mssql, setup_connection_task, create_firewall_rule] >> create_sql_table
+
     (
-        # TEST SETUP
-        create_dataset
-        >> create_table
+        [insert_bigquery_data, create_sql_table]
         # TEST BODY
         >> bigquery_to_mssql
-        # TEST TEARDOWN
-        >> delete_dataset
     )
+    # TEST TEARDOWN
+    bigquery_to_mssql >> [
+        delete_bigquery_dataset,
+        delete_firewall_rule,
+        delete_gce_instance,
+        delete_connection,
+    ]
+    delete_gce_instance >> delete_persistent_disk
 
     from tests.system.utils.watcher import watcher
 
     # This test needs watcher in order to properly mark success/failure
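
A minimal sketch (not part of the patch) of the behavior this commit fixes. It assumes the
apache-airflow-providers-google and apache-airflow-providers-microsoft-mssql packages are
installed; the task id, table ids, and connection id below are hypothetical. Before the fix,
get_sql_hook() passed the operator's connection id to MsSqlHook as mysql_conn_id, a keyword
MsSqlHook does not use, so the hook effectively fell back to the default mssql_default
connection; after the fix, the configured id reaches the hook:

    from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

    op = BigQueryToMsSqlOperator(
        task_id="check_conn_id_forwarding",  # hypothetical task id
        source_project_dataset_table="my-project.my_dataset.my_table",  # hypothetical source table
        target_table_name="my_mssql_table",  # hypothetical target table
        mssql_conn_id="my_mssql_conn",  # hypothetical connection id
        replace=False,
    )

    # get_sql_hook() only constructs the hook; no database connection is opened here.
    hook = op.get_sql_hook()
    assert hook.mssql_conn_id == "my_mssql_conn"  # holds with this fix; previously "mssql_default"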