Add encryption_configuration parameter to BigQueryCheckOperator and BigQueryTableCheckOperator #39432

Merged
52 changes: 47 additions & 5 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -203,7 +203,25 @@ def get_openlineage_facets_on_complete(self, task_instance):
)


class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
class _BigQueryOperatorsEncryptionConfigurationMixin:
"""A class to handle the configuration for BigQueryHook.insert_job method."""

# Note: to add this feature to a new operator, include that operator's class name in the type
# annotation of `self` below, then inherit this mixin in the target operator,
# e.g. BigQueryCheckOperator, BigQueryTableCheckOperator.
def include_encryption_configuration( # type:ignore[misc]
self: BigQueryCheckOperator | BigQueryTableCheckOperator,
configuration: dict,
config_key: str,
) -> None:
"""Add encryption_configuration to destinationEncryptionConfiguration key if it is not None."""
if self.encryption_configuration is not None:
configuration[config_key]["destinationEncryptionConfiguration"] = self.encryption_configuration
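
For illustration only (not part of the diff), a minimal sketch of the transformation the mixin applies to the job configuration that `_submit_job` builds further down; the dict shape and the `destinationEncryptionConfiguration` key are taken from the surrounding code, and the table and key names are placeholders:

.. code-block:: python

    # Configuration as built in _submit_job before the mixin runs.
    configuration = {"query": {"query": "SELECT COUNT(*) FROM dataset.table", "useLegacySql": False}}
    encryption_configuration = {
        "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
    }
    # include_encryption_configuration(configuration, "query") boils down to:
    if encryption_configuration is not None:
        configuration["query"]["destinationEncryptionConfiguration"] = encryption_configuration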


class BigQueryCheckOperator(
_BigQueryDbHookMixin, SQLCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
):
"""Performs checks against BigQuery.

This operator expects a SQL query that returns a single row. Each value on
@@ -248,6 +266,13 @@ class BigQueryCheckOperator(_BigQueryDbHookMixin, SQLCheckOperator):
Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account. (templated)
:param labels: a dictionary containing labels for the table, passed to BigQuery.
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).

.. code-block:: python

encryption_configuration = {
"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
}
:param deferrable: Run operator in the deferrable mode.
:param poll_interval: (Deferrable mode only) polling period in seconds to
check for the status of job.
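
As a usage sketch (not part of the diff), instantiating the operator with a customer-managed key could look like the following; `sql` and `use_legacy_sql` come from the operator's existing SQL-check interface, `location` is shown in the constructor below, and the table and key names are placeholders:

.. code-block:: python

    from airflow.providers.google.cloud.operators.bigquery import BigQueryCheckOperator

    check_row_count = BigQueryCheckOperator(
        task_id="check_row_count",
        sql="SELECT COUNT(*) FROM dataset.table",
        use_legacy_sql=False,
        location="US",
        encryption_configuration={
            "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
        },
    )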
@@ -272,6 +297,7 @@ def __init__(
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
labels: dict | None = None,
encryption_configuration: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
poll_interval: float = 4.0,
**kwargs,
@@ -282,6 +308,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain
self.labels = labels
self.encryption_configuration = encryption_configuration
self.deferrable = deferrable
self.poll_interval = poll_interval

@@ -293,6 +320,8 @@ def _submit_job(
"""Submit a new job and get the job id for polling the status using Trigger."""
configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}

self.include_encryption_configuration(configuration, "query")

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
@@ -767,7 +796,9 @@ def execute(self, context=None):
self.log.info("All tests have passed")


class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
class BigQueryTableCheckOperator(
_BigQueryDbHookMixin, SQLTableCheckOperator, _BigQueryOperatorsEncryptionConfigurationMixin
):
"""
Subclasses the SQLTableCheckOperator in order to provide a job id for OpenLineage to parse.

@@ -795,6 +826,13 @@ class BigQueryTableCheckOperator(_BigQueryDbHookMixin, SQLTableCheckOperator):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param labels: a dictionary containing labels for the table, passed to BigQuery
:param encryption_configuration: [Optional] Custom encryption configuration (e.g., Cloud KMS keys).

.. code-block:: python

encryption_configuration = {
"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
}
"""

template_fields: Sequence[str] = tuple(set(SQLTableCheckOperator.template_fields) | {"gcp_conn_id"})
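
For the table-level check, a hedged usage sketch (not part of the diff) mirroring the test added at the bottom of this PR; the table name and key are placeholders:

.. code-block:: python

    from airflow.providers.google.cloud.operators.bigquery import BigQueryTableCheckOperator

    table_check = BigQueryTableCheckOperator(
        task_id="table_check",
        table="dataset.test_table",
        checks={"row_count_check": {"check_statement": "COUNT(*) = 1"}},
        location="EU",
        encryption_configuration={
            "kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
        },
    )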
@@ -812,6 +850,7 @@ def __init__(
location: str | None = None,
impersonation_chain: str | Sequence[str] | None = None,
labels: dict | None = None,
encryption_configuration: dict | None = None,
**kwargs,
) -> None:
super().__init__(table=table, checks=checks, partition_clause=partition_clause, **kwargs)
@@ -820,6 +859,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain
self.labels = labels
self.encryption_configuration = encryption_configuration

def _submit_job(
self,
@@ -829,6 +869,8 @@ def _submit_job(
"""Submit a new job and get the job id for polling the status using Trigger."""
configuration = {"query": {"query": self.sql, "useLegacySql": self.use_legacy_sql}}

self.include_encryption_configuration(configuration, "query")

return hook.insert_job(
configuration=configuration,
project_id=hook.project_id,
@@ -1222,7 +1264,7 @@ class BigQueryExecuteQueryOperator(GoogleCloudBaseOperator):
.. code-block:: python

encryption_configuration = {
"kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key",
"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
}
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
@@ -1462,7 +1504,7 @@ class BigQueryCreateEmptyTableOperator(GoogleCloudBaseOperator):
.. code-block:: python

encryption_configuration = {
"kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key",
"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
}
:param location: The location used for the operation.
:param cluster_fields: [Optional] The fields used for clustering.
@@ -1690,7 +1732,7 @@ class BigQueryCreateExternalTableOperator(GoogleCloudBaseOperator):
.. code-block:: python

encryption_configuration = {
"kmsKeyName": "projects/testp/locations/us/keyRings/test-kr/cryptoKeys/test-key",
"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
}
:param location: The location used for the operation.
:param impersonation_chain: Optional service account to impersonate using short-term
47 changes: 47 additions & 0 deletions tests/providers/google/cloud/operators/test_bigquery.py
@@ -48,6 +48,7 @@
BigQueryInsertJobOperator,
BigQueryIntervalCheckOperator,
BigQueryPatchDatasetOperator,
BigQueryTableCheckOperator,
BigQueryUpdateDatasetOperator,
BigQueryUpdateTableOperator,
BigQueryUpdateTableSchemaOperator,
@@ -2443,3 +2444,49 @@ def test_bigquery_column_check_operator_fails(
)
with pytest.raises(AirflowException):
ti.task.execute(MagicMock())


class TestBigQueryTableCheckOperator:
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryJob")
def test_encryption_configuration(self, mock_job, mock_hook):
encryption_configuration = {
"kmsKeyName": "projects/PROJECT/locations/LOCATION/keyRings/KEY_RING/cryptoKeys/KEY",
}

mock_job.result.return_value.to_dataframe.return_value = pd.DataFrame(
{
"check_name": ["row_count_check"],
"check_result": [1],
}
)
mock_hook.return_value.insert_job.return_value = mock_job
mock_hook.return_value.project_id = TEST_GCP_PROJECT_ID

check_statement = "COUNT(*) = 1"
operator = BigQueryTableCheckOperator(
task_id="TASK_ID",
table="test_table",
checks={"row_count_check": {"check_statement": check_statement}},
encryption_configuration=encryption_configuration,
location=TEST_DATASET_LOCATION,
)

operator.execute(MagicMock())
mock_hook.return_value.insert_job.assert_called_with(
configuration={
"query": {
"query": f"""SELECT check_name, check_result FROM (
SELECT 'row_count_check' AS check_name, MIN(row_count_check) AS check_result
FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS row_count_check
FROM test_table ) AS sq
) AS check_table""",
"useLegacySql": True,
"destinationEncryptionConfiguration": encryption_configuration,
}
},
project_id=TEST_GCP_PROJECT_ID,
location=TEST_DATASET_LOCATION,
job_id="",
nowait=False,
)