From 934f6b9084e389ed06c8447b4e335d4c633d1375 Mon Sep 17 00:00:00 2001 From: "M. Olcay Tercanli" Date: Wed, 24 Apr 2024 07:57:35 +0000 Subject: [PATCH] Add encryption_configuration parameter to BigQueryCheckOperator and BigQueryTableCheckOperator --- .../google/cloud/operators/bigquery.py | 52 +++++++++++++++++-- .../google/cloud/operators/test_bigquery.py | 47 +++++++++++++++++ 2 files changed, 94 insertions(+), 5 deletions(-) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 1cf0f9ee9a350..dcd971ad0b221 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -203,7 +203,25 @@ def get_openlineage_facets_on_complete(self, task_instance): ) -class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator): +class _BigQueryOperatorsEncryptionConfigurationMixin: + """A class to handle the configuration for BigQueryHook.insert_job method.""" + + # Note: If you want to add this feature to a new operator you can include the class name in the type + # annotation of the `self`. Then you can inherit this class in the target operator. + # e.g: BigQueryCheckOperator, BigQueryTableCheckOperator + def include_encryption_configuration( # type:ignore[misc] + self: BigQueryCheckOperator | BigQueryTableCheckOperator, + configuration: dict, + config_key: str, + ) -> None: + """Add encryption_configuration to destinationEncryptionConfiguration key if it is not None.""" + if self.encryption_configuration is not None: + configuration[config_key]["destinationEncryptionConfiguration"] = self.encryption_configuration + + +class BigQueryCheckOperator( + _BigQueryDbHookMixin, SQLCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin +): """Performs checks against BigQuery. This operator expects a SQL query that returns a single row. Each value on @@ -248,6 +266,13 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator): Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account. (templated) :param labels: a dictionary containing labels for the table, passed to BigQuery. + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + + .. code-block:: python + + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } :param deferrable: Run operator in the deferrable mode. :param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job. @@ -272,6 +297,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, + encryption_configuration: dict | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), poll_interval: float = 4.0, **kwargs, @@ -282,6 +308,7 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels + self.encryption_configuration = encryption_configuration self.deferrable = deferrable self.poll_interval = poll_interval @@ -293,6 +320,8 @@ def _submit_job( """Submit a new job and get the job id for polling the status using Trigger.""" configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}} + self.include_encryption_configuration(configuration, "query") + return hook.insert_job( configuration=configuration, project_id=hook.project_id, @@ -767,7 +796,9 @@ def execute(self, context=None): self.log.info("All tests have passed") -class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator): +class BigQueryTableCheckOperator( + _BigQueryDbHookMixin, SQLTableCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin +): """ Subclasses the SQLTableCheckOperator in order to provide a job id for OpenLineage to parse. @@ -795,6 +826,13 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator): Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). :param labels: a dictionary containing labels for the table, passed to BigQuery + :param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys). + + .. code-block:: python + + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } """ template_fields: Sequence[str] = tuple(set(SQLTableCheckOperator.template_fields) | {"gcp_conn_id"}) @@ -812,6 +850,7 @@ def __init__( location: str | None = None, impersonation_chain: str | Sequence[str] | None = None, labels: dict | None = None, + encryption_configuration: dict | None = None, **kwargs, ) -> None: super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs) @@ -820,6 +859,7 @@ def __init__( self.location = location self.impersonation_chain = impersonation_chain self.labels = labels + self.encryption_configuration = encryption_configuration def _submit_job( self, @@ -829,6 +869,8 @@ def _submit_job( """Submit a new job and get the job id for polling the status using Trigger.""" configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}} + self.include_encryption_configuration(configuration, "query") + return hook.insert_job( configuration=configuration, project_id=hook.project_id, @@ -1222,7 +1264,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator): .. code-block:: python encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key", + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", } :param impersonation_chain: Optional service account to impersonate using short-term credentials, or chained list of accounts required to get the access_token @@ -1462,7 +1504,7 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator): .. code-block:: python encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key", + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", } :param location: The location used for the operation. :param cluster_fields: [Optional] The fields used for clustering. @@ -1690,7 +1732,7 @@ class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator): .. code-block:: python encryption_configuration = { - "kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key", + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", } :param location: The location used for the operation. :param impersonation_chain: Optional service account to impersonate using short-term diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index a70d8f216437e..ba94347437eef 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -48,6 +48,7 @@ BigQueryInsertJobOperator, BigQueryIntervalCheckOperator, BigQueryPatchDatasetOperator, + BigQueryTableCheckOperator, BigQueryUpdateDatasetOperator, BigQueryUpdateTableOperator, BigQueryUpdateTableSchemaOperator, @@ -2443,3 +2444,49 @@ def test_bigquery_column_check_operator_fails( ) with pytest.raises(AirflowException): ti.task.execute(MagicMock()) + + +class TestBigQueryTableCheckOperator: + @mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob") + def test_encryption_configuration(self, mock_job, mock_hook): + encryption_configuration = { + "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY", + } + + mock_job.result.return_value.to_dataframe.return_value = pd.DataFrame( + { + "check_name": ["row_count_check"], + "check_result": [1], + } + ) + mock_hook.return_value.insert_job.return_value = mock_job + mock_hook.return_value.project_id = TEST_GCP_PROJECT_ID + + check_statement = "COUNT(*) = 1" + operator = BigQueryTableCheckOperator( + task_id="TASK_ID", + table="test_table", + checks={"row_count_check": {"check_statement": check_statement}}, + encryption_configuration=encryption_configuration, + location=TEST_DATASET_LOCATION, + ) + + operator.execute(MagicMock()) + mock_hook.return_value.insert_job.assert_called_with( + configuration={ + "query": { + "query": f"""SELECT check_name, check_result FROM ( + SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result + FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS row_count_check + FROM test_table ) AS sq + ) AS check_table""", + "useLegacySql": True, + "destinationEncryptionConfiguration": encryption_configuration, + } + }, + project_id=TEST_GCP_PROJECT_ID, + location=TEST_DATASET_LOCATION, + job_id="", + nowait=False, + )