fix OpenLineage extraction for GCP deferrable operators #40521

Merged 1 commit on Jul 1, 2024
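What this PR fixes, in short: in deferrable mode these operators hand off to a trigger after `execute()`, and `execute_complete()` later runs on a freshly instantiated copy of the operator, so the BigQuery job id stashed in the private `_job_id` attribute (and the hook itself) never survived the deferral. OpenLineage's on-complete extraction then either failed outright (`self.hook` is `None` on the resumed instance) or was missing the `externalQuery` run facet. The diff below standardizes on the public `self.job_id` attribute, restores it from the trigger event in `execute_complete()`, and lazily re-creates the `BigQueryHook` before building facets. A dependency-free sketch of the two-instance lifecycle behind the bug (the class and values are illustrative, not Airflow's actual deferral machinery):

```python
from __future__ import annotations

# Illustrative sketch only -- a toy stand-in, not Airflow's TaskDeferred
# machinery. It models why attributes assigned before deferral are gone
# when execute_complete() runs, and why the fix reads the job id back
# from the trigger event.


class SketchOperator:
    def __init__(self) -> None:
        self.job_id: str | None = None

    def execute(self, context: dict) -> None:
        # Instance A submits the job and records the id, then defers;
        # in Airflow the worker slot is released at this point.
        self.job_id = "airflow_task_abc123"

    def execute_complete(self, context: dict, event: dict) -> None:
        # Airflow resumes on a *fresh* instance B, so self.job_id is None
        # again unless restored from the event -- the core of this PR.
        self.job_id = event.get("job_id")


instance_a = SketchOperator()
instance_a.execute(context={})

instance_b = SketchOperator()  # state from instance A does not carry over
assert instance_b.job_id is None
instance_b.execute_complete(
    context={},
    event={"status": "success", "job_id": instance_a.job_id},
)
assert instance_b.job_id == "airflow_task_abc123"
```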
30 changes: 19 additions & 11 deletions airflow/providers/google/cloud/transfers/bigquery_to_gcs.py
@@ -142,8 +142,6 @@ def __init__(
         self.hook: BigQueryHook | None = None
         self.deferrable = deferrable
 
-        self._job_id: str = ""
-
     @staticmethod
     def _handle_job_error(job: BigQueryJob | UnknownJob) -> None:
         if job.error_result:
@@ -212,7 +210,7 @@ def execute(self, context: Context):
         self.hook = hook
 
         configuration = self._prepare_configuration()
-        job_id = hook.generate_job_id(
+        self.job_id = hook.generate_job_id(
             job_id=self.job_id,
             dag_id=self.dag_id,
             task_id=self.task_id,
@@ -224,14 +222,14 @@ def execute(self, context: Context):
         try:
             self.log.info("Executing: %s", configuration)
             job: BigQueryJob | UnknownJob = self._submit_job(
-                hook=hook, job_id=job_id, configuration=configuration
+                hook=hook, job_id=self.job_id, configuration=configuration
             )
         except Conflict:
             # If the job already exists retrieve it
             job = hook.get_job(
                 project_id=self.project_id,
                 location=self.location,
-                job_id=job_id,
+                job_id=self.job_id,
             )
             if job.state in self.reattach_states:
                 # We are reattaching to a job
@@ -240,12 +238,12 @@ def execute(self, context: Context):
             else:
                 # Same job configuration so we need force_rerun
                 raise AirflowException(
-                    f"Job with id: {job_id} already exists and is in {job.state} state. If you "
+                    f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
                     f"want to force rerun it consider setting `force_rerun=True`."
                     f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
                 )
 
-        self._job_id = job.job_id
+        self.job_id = job.job_id
         conf = job.to_api_repr()["configuration"]["extract"]["sourceTable"]
         dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"]
         BigQueryTableLink.persist(
@@ -261,7 +259,7 @@ def execute(self, context: Context):
                 timeout=self.execution_timeout,
                 trigger=BigQueryInsertJobTrigger(
                     conn_id=self.gcp_conn_id,
-                    job_id=self._job_id,
+                    job_id=self.job_id,
                     project_id=self.project_id or self.hook.project_id,
                     location=self.location or self.hook.location,
                     impersonation_chain=self.impersonation_chain,
@@ -284,6 +282,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
                 self.task_id,
                 event["message"],
             )
+        # Save job_id as an attribute to be later used by listeners
+        self.job_id = event.get("job_id")
 
     def get_openlineage_facets_on_complete(self, task_instance):
         """Implement on_complete as we will include final BQ job id."""
@@ -303,7 +303,15 @@ def get_openlineage_facets_on_complete(self, task_instance):
         )
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        table_object = self.hook.get_client(self.hook.project_id).get_table(self.source_project_dataset_table)
+        if not self.hook:
+            self.hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        project_id = self.project_id or self.hook.project_id
+        table_object = self.hook.get_client(project_id).get_table(self.source_project_dataset_table)
 
         input_dataset = Dataset(
             namespace="bigquery",
@@ -347,9 +355,9 @@ def get_openlineage_facets_on_complete(self, task_instance):
             output_datasets.append(dataset)
 
         run_facets = {}
-        if self._job_id:
+        if self.job_id:
             run_facets = {
-                "externalQuery": ExternalQueryRunFacet(externalQueryId=self._job_id, source="bigquery"),
+                "externalQuery": ExternalQueryRunFacet(externalQueryId=self.job_id, source="bigquery"),
             }
 
         return OperatorLineage(inputs=[input_dataset], outputs=output_datasets, run_facets=run_facets)
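A note on the guard at the end of `get_openlineage_facets_on_complete`: on the non-deferred path `execute()` assigns `self.job_id = job.job_id` directly, while on the deferred path `execute_complete()` restores it from the trigger event; either way, the `externalQuery` facet is emitted only when an id is actually present. A hedged, dependency-free model of just that guard (these dataclasses merely mirror the fields used in the diff, not the real OpenLineage classes):

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ExternalQueryRunFacet:  # illustrative stand-in, not the real facet class
    externalQueryId: str
    source: str


@dataclass
class OperatorLineage:  # illustrative stand-in for the extractor return type
    inputs: list = field(default_factory=list)
    outputs: list = field(default_factory=list)
    run_facets: dict = field(default_factory=dict)


def build_run_facets(job_id: str | None) -> dict:
    # Mirrors the diff: emit externalQuery only when a job id survived,
    # whichever path (deferred or not) produced it.
    if not job_id:
        return {}
    return {"externalQuery": ExternalQueryRunFacet(externalQueryId=job_id, source="bigquery")}


assert build_run_facets(None) == {}
assert build_run_facets("job_123")["externalQuery"].externalQueryId == "job_123"
```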
19 changes: 15 additions & 4 deletions airflow/providers/google/cloud/transfers/gcs_to_bigquery.py
@@ -461,6 +461,8 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
                 self.task_id,
                 event["message"],
             )
+        # Save job_id as an attribute to be later used by listeners
+        self.job_id = event.get("job_id")
         return self._find_max_value_in_column()
 
     def _find_max_value_in_column(self):
@@ -757,17 +759,26 @@ def get_openlineage_facets_on_complete(self, task_instance):
         )
         from airflow.providers.openlineage.extractors import OperatorLineage
 
-        table_object = self.hook.get_client(self.hook.project_id).get_table(
-            self.destination_project_dataset_table
-        )
+        if not self.hook:
+            self.hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        project_id = self.project_id or self.hook.project_id
+        table_object = self.hook.get_client(project_id).get_table(self.destination_project_dataset_table)
 
         output_dataset_facets = get_facets_from_bq_table(table_object)
 
+        source_objects = (
+            self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
+        )
         input_dataset_facets = {
             "schema": output_dataset_facets["schema"],
         }
         input_datasets = []
-        for blob in sorted(self.source_objects):
+        for blob in sorted(source_objects):
             additional_facets = {}
 
             if "*" in blob:
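One behavioral detail worth flagging in the hunk above: `GCSToBigQueryOperator` accepts `source_objects` as either a single string or a list of strings, and `sorted()` over a bare string iterates its characters, so the extraction now normalizes to a list before sorting. A standalone illustration of the pattern (the helper name is mine, not the provider's):

```python
def normalize_source_objects(source_objects):
    """Mirror the diff's normalization: accept one URI string or a list of them."""
    return source_objects if isinstance(source_objects, list) else [source_objects]


# Without the guard, sorted() happily sorts the characters of a lone string:
assert sorted("b.csv") == [".", "b", "c", "s", "v"]
assert sorted(normalize_source_objects("b.csv")) == ["b.csv"]
assert sorted(normalize_source_objects(["b.csv", "a.csv"])) == ["a.csv", "b.csv"]
```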
20 changes: 20 additions & 0 deletions tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py
@@ -187,6 +187,26 @@ def test_execute_deferrable_mode(self, mock_hook):
             nowait=True,
         )
 
+    def test_execute_complete_reassigns_job_id(self):
+        """Assert that we use job_id from event after deferral."""
+
+        operator = BigQueryToGCSOperator(
+            project_id=JOB_PROJECT_ID,
+            task_id=TASK_ID,
+            source_project_dataset_table=f"{PROJECT_ID}.{TEST_DATASET}.{TEST_TABLE_ID}",
+            destination_cloud_storage_uris=[f"gs://{TEST_BUCKET}/{TEST_FOLDER}/"],
+            deferrable=True,
+            job_id=None,
+        )
+        job_id = "123456"
+
+        assert operator.job_id is None
+        operator.execute_complete(
+            context=MagicMock(),
+            event={"status": "success", "message": "Job completed", "job_id": job_id},
+        )
+        assert operator.job_id == job_id
+
     @pytest.mark.parametrize(
         ("gcs_uri", "expected_dataset_name"),
         (
22 changes: 22 additions & 0 deletions tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py
@@ -1913,6 +1913,28 @@ def test_schema_fields_int_without_external_table_async_should_execute_successfu
 
     bq_hook.return_value.insert_job.assert_has_calls(calls)
 
+    @mock.patch(GCS_TO_BQ_PATH.format("BigQueryHook"))
+    def test_execute_complete_reassigns_job_id(self, bq_hook):
+        """Assert that we use job_id from event after deferral."""
+
+        operator = GCSToBigQueryOperator(
+            task_id=TASK_ID,
+            bucket=TEST_BUCKET,
+            source_objects=TEST_SOURCE_OBJECTS,
+            destination_project_dataset_table=TEST_EXPLICIT_DEST,
+            deferrable=True,
+            job_id=None,
+        )
+        generated_job_id = "123456"
+
+        assert operator.job_id is None
+
+        operator.execute_complete(
+            context=MagicMock(),
+            event={"status": "success", "message": "Job completed", "job_id": generated_job_id},
+        )
+        assert operator.job_id == generated_job_id
+
     def create_context(self, task):
         dag = DAG(dag_id="dag")
         logical_date = datetime(2022, 1, 1, 0, 0, 0)
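Both new tests drive `execute_complete` directly with a synthetic trigger event, so they should be runnable locally with plain pytest selection, along the lines of (paths relative to the repo root):

```
pytest tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py \
       tests/providers/google/cloud/transfers/test_gcs_to_bigquery.py \
       -k test_execute_complete_reassigns_job_id
```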