Use insert_job in the BigQueryToGCPOperator and adjust links #24416

Merged 3 commits on Jun 15, 2022
91 changes: 85 additions & 6 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -23,7 +23,9 @@
import hashlib
import json
import logging
import re
import time
import uuid
import warnings
from copy import deepcopy
from datetime import datetime, timedelta
@@ -1698,7 +1700,7 @@ def run_load(
f"Please only use one or more of the following options: {allowed_schema_update_options}"
)

destination_project, destination_dataset, destination_table = _split_tablename(
destination_project, destination_dataset, destination_table = self.split_tablename(
table_input=destination_project_dataset_table,
default_project_id=self.project_id,
var_name='destination_project_dataset_table',
@@ -1850,7 +1852,7 @@ def run_copy(

source_project_dataset_tables_fixup = []
for source_project_dataset_table in source_project_dataset_tables:
source_project, source_dataset, source_table = _split_tablename(
source_project, source_dataset, source_table = self.split_tablename(
table_input=source_project_dataset_table,
default_project_id=self.project_id,
var_name='source_project_dataset_table',
@@ -1859,7 +1861,7 @@
{'projectId': source_project, 'datasetId': source_dataset, 'tableId': source_table}
)

destination_project, destination_dataset, destination_table = _split_tablename(
destination_project, destination_dataset, destination_table = self.split_tablename(
table_input=destination_project_dataset_table, default_project_id=self.project_id
)
configuration = {
@@ -1924,7 +1926,7 @@ def run_extract(
if not self.project_id:
raise ValueError("The project_id should be set")

source_project, source_dataset, source_table = _split_tablename(
source_project, source_dataset, source_table = self.split_tablename(
table_input=source_project_dataset_table,
default_project_id=self.project_id,
var_name='source_project_dataset_table',
@@ -2092,7 +2094,7 @@ def run_query(
)

if destination_dataset_table:
destination_project, destination_dataset, destination_table = _split_tablename(
destination_project, destination_dataset, destination_table = self.split_tablename(
table_input=destination_dataset_table, default_project_id=self.project_id
)

@@ -2180,6 +2182,83 @@ def run_query(
self.running_job_id = job.job_id
return job.job_id

    def generate_job_id(self, job_id, dag_id, task_id, logical_date, configuration, force_rerun=False):
        if force_rerun:
            hash_base = str(uuid.uuid4())
        else:
            hash_base = json.dumps(configuration, sort_keys=True)

        uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest()

        if job_id:
            return f"{job_id}_{uniqueness_suffix}"

        exec_date = logical_date.isoformat()
        job_id = f"airflow_{dag_id}_{task_id}_{exec_date}_{uniqueness_suffix}"
        return re.sub(r"[:\-+.]", "_", job_id)
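
A quick sketch of how the relocated helper behaves (values are made up; `generate_job_id` only hashes the configuration and formats the id, so no connection is needed):

```python
from datetime import datetime

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook(gcp_conn_id="google_cloud_default")

# No explicit job_id: the id is built from dag/task/logical date plus an MD5 of the configuration.
job_id = hook.generate_job_id(
    job_id=None,
    dag_id="example_dag",
    task_id="run_query",
    logical_date=datetime(2022, 6, 15),
    configuration={"query": {"query": "SELECT 1", "useLegacySql": False}},
    force_rerun=False,
)
# e.g. "airflow_example_dag_run_query_2022_06_15T00_00_00_<md5 hex digest>"
```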

    def split_tablename(
        self, table_input: str, default_project_id: str, var_name: Optional[str] = None
    ) -> Tuple[str, str, str]:

        if '.' not in table_input:
            raise ValueError(f'Expected table name in the format of <dataset>.<table>. Got: {table_input}')

        if not default_project_id:
            raise ValueError("INTERNAL: No default project is specified")

        def var_print(var_name):
            if var_name is None:
                return ""
            else:
                return f"Format exception for {var_name}: "

        if table_input.count('.') + table_input.count(':') > 3:
            raise Exception(f'{var_print(var_name)}Use either : or . to specify project got {table_input}')
        cmpt = table_input.rsplit(':', 1)
        project_id = None
        rest = table_input
        if len(cmpt) == 1:
            project_id = None
            rest = cmpt[0]
        elif len(cmpt) == 2 and cmpt[0].count(':') <= 1:
            if cmpt[-1].count('.') != 2:
                project_id = cmpt[0]
                rest = cmpt[1]
        else:
            raise Exception(
                f'{var_print(var_name)}Expect format of (<project:)<dataset>.<table>, got {table_input}'
            )

        cmpt = rest.split('.')
        if len(cmpt) == 3:
            if project_id:
                raise ValueError(f"{var_print(var_name)}Use either : or . to specify project")
            project_id = cmpt[0]
            dataset_id = cmpt[1]
            table_id = cmpt[2]

        elif len(cmpt) == 2:
            dataset_id = cmpt[0]
            table_id = cmpt[1]
        else:
            raise Exception(
                f'{var_print(var_name)} Expect format of (<project.|<project:)<dataset>.<table>, '
                f'got {table_input}'
            )

        if project_id is None:
            if var_name is not None:
                self.log.info(
                    'Project not included in %s: %s; using project "%s"',
                    var_name,
                    table_input,
                    default_project_id,
                )
            project_id = default_project_id

        return project_id, dataset_id, table_id
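
For reference, a sketch of the inputs the now-public `split_tablename` accepts (project ids and table names below are invented):

```python
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook(gcp_conn_id="google_cloud_default")

# Colon-qualified project wins over the default:
hook.split_tablename("my-project:my_dataset.my_table", default_project_id="fallback-project")
# ('my-project', 'my_dataset', 'my_table')

# Dot-qualified project behaves the same way:
hook.split_tablename("my-project.my_dataset.my_table", default_project_id="fallback-project")
# ('my-project', 'my_dataset', 'my_table')

# No project part: the default project is filled in (and logged when var_name is given):
hook.split_tablename("my_dataset.my_table", default_project_id="fallback-project")
# ('fallback-project', 'my_dataset', 'my_table')
```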


class BigQueryConnection:
"""
@@ -2771,7 +2850,7 @@ def _bq_cast(string_field: str, bq_type: str) -> Union[None, int, float, bool, str]:
return string_field


def _split_tablename(
def split_tablename(
table_input: str, default_project_id: str, var_name: Optional[str] = None
) -> Tuple[str, str, str]:

4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/links/bigquery.py
@@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Google BigQuery links."""
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from airflow.models import BaseOperator
from airflow.providers.google.cloud.links.base import BaseGoogleLink
@@ -66,9 +66,9 @@ class BigQueryTableLink(BaseGoogleLink):
def persist(
context: "Context",
task_instance: BaseOperator,
dataset_id: str,
project_id: str,
table_id: str,
dataset_id: Optional[str] = None,
):
task_instance.xcom_push(
context,
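Making `dataset_id` optional lets callers persist a link either from a full table reference dict or from a plain table string. A minimal sketch of both call shapes (the wrapper function and all values are illustrative):

```python
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink


def persist_table_links(context, task_instance):
    # Full table reference taken from the BigQuery API response: pass dataset_id explicitly.
    BigQueryTableLink.persist(
        context=context,
        task_instance=task_instance,
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
    )
    # Plain-string table id straight from the job configuration: dataset_id can now be omitted.
    BigQueryTableLink.persist(
        context=context,
        task_instance=task_instance,
        project_id="my-project",
        table_id="my_dataset.my_table",
    )
```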
64 changes: 34 additions & 30 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -19,18 +19,15 @@

"""This module contains Google BigQuery operators."""
import enum
import hashlib
import json
import re
import uuid
import warnings
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union

import attr
from google.api_core.exceptions import Conflict
from google.api_core.retry import Retry
from google.cloud.bigquery import DEFAULT_RETRY
from google.cloud.bigquery import DEFAULT_RETRY, CopyJob, ExtractJob, LoadJob, QueryJob

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
@@ -2119,21 +2116,6 @@ def _handle_job_error(job: BigQueryJob) -> None:
        if job.error_result:
            raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}")

    def _job_id(self, context):
        if self.force_rerun:
            hash_base = str(uuid.uuid4())
        else:
            hash_base = json.dumps(self.configuration, sort_keys=True)

        uniqueness_suffix = hashlib.md5(hash_base.encode()).hexdigest()

        if self.job_id:
            return f"{self.job_id}_{uniqueness_suffix}"

        exec_date = context['logical_date'].isoformat()
        job_id = f"airflow_{self.dag_id}_{self.task_id}_{exec_date}_{uniqueness_suffix}"
        return re.sub(r"[:\-+.]", "_", job_id)

    def execute(self, context: Any):
        hook = BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
Expand All @@ -2142,7 +2124,14 @@ def execute(self, context: Any):
        )
        self.hook = hook

        job_id = self._job_id(context)
        job_id = hook.generate_job_id(
            job_id=self.job_id,
            dag_id=self.dag_id,
            task_id=self.task_id,
            logical_date=context["logical_date"],
            configuration=self.configuration,
            force_rerun=self.force_rerun,
        )

try:
self.log.info(f"Executing: {self.configuration}")
@@ -2167,16 +2156,31 @@
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
)

        if "query" in job.to_api_repr()["configuration"]:
            if "destinationTable" in job.to_api_repr()["configuration"]["query"]:
                table = job.to_api_repr()["configuration"]["query"]["destinationTable"]
                BigQueryTableLink.persist(
                    context=context,
                    task_instance=self,
                    dataset_id=table["datasetId"],
                    project_id=table["projectId"],
                    table_id=table["tableId"],
                )
        job_types = {
            LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"],
            CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"],
            ExtractJob._JOB_TYPE: ["sourceTable"],
            QueryJob._JOB_TYPE: ["destinationTable"],
        }

        if self.project_id:
            for job_type, tables_prop in job_types.items():
                job_configuration = job.to_api_repr()["configuration"]
                if job_type in job_configuration:
                    for table_prop in tables_prop:
                        if table_prop in job_configuration[job_type]:
                            table = job_configuration[job_type][table_prop]
                            persist_kwargs = {
                                "context": context,
                                "task_instance": self,
                                "project_id": self.project_id,
                                "table_id": table,
                            }
                            if not isinstance(table, str):
                                persist_kwargs["table_id"] = table["tableId"]
                                persist_kwargs["dataset_id"] = table["datasetId"]

                            BigQueryTableLink.persist(**persist_kwargs)

self.job_id = job.job_id
return job.job_id
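
To illustrate what the new loop is looking at, here is a made-up query-job API representation and the link kwargs that would be derived from its destinationTable:

```python
api_repr = {
    "configuration": {
        "query": {
            "query": "SELECT 1",
            "destinationTable": {
                "projectId": "my-project",
                "datasetId": "my_dataset",
                "tableId": "my_table",
            },
        }
    }
}

table = api_repr["configuration"]["query"]["destinationTable"]
persist_kwargs = {
    "project_id": "my-project",        # the operator's own project_id
    "table_id": table["tableId"],      # dict reference: split into table_id and dataset_id
    "dataset_id": table["datasetId"],
}
print(persist_kwargs)
# {'project_id': 'my-project', 'table_id': 'my_table', 'dataset_id': 'my_dataset'}
```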
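
End to end, nothing should change for DAG authors. A minimal, hypothetical task like the one below now gets its job_id from `hook.generate_job_id()` and its table links from the job-type mapping above:

```python
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator

insert_query_job = BigQueryInsertJobOperator(
    task_id="insert_query_job",
    project_id="my-project",
    configuration={
        "query": {
            "query": "SELECT 1",
            "useLegacySql": False,
        }
    },
    force_rerun=True,  # hash a fresh UUID instead of the configuration, so the job always reruns
)
```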