Skip to content

Commit

Permalink
Fix DataprocJobBaseOperator not being compatible with dotted names (apache#23439).
Browse files Browse the repository at this point in the history

 * The job_name parameter is now sanitized: dots are replaced with underscores.
  • Loading branch information
gmcrocetti committed May 20, 2022
1 parent baae70c commit fa151fc
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 22 deletions.
7 changes: 4 additions & 3 deletions airflow/providers/google/cloud/hooks/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
job_type: str,
properties: Optional[Dict[str, str]] = None,
) -> None:
name = task_id + "_" + str(uuid.uuid4())[:8]
name = f"{task_id.replace('.', '_')}_{uuid.uuid4()!s:.8}"
self.job_type = job_type
self.job = {
"job": {
Expand Down Expand Up @@ -175,11 +175,12 @@ def set_python_main(self, main: str) -> None:

def set_job_name(self, name: str) -> None:
"""
Set Dataproc job name.
Set Dataproc job name. Job name is sanitized, replacing dots by underscores.
:param name: Job name.
"""
self.job["job"]["reference"]["job_id"] = name + "_" + str(uuid.uuid4())[:8]
sanitized_name = f"{name.replace('.', '_')}_{uuid.uuid4()!s:.8}"
self.job["job"]["reference"]["job_id"] = sanitized_name

def build(self) -> Dict:
"""
Expand Down
32 changes: 21 additions & 11 deletions tests/providers/google/cloud/hooks/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest
from google.api_core.gapic_v1.method import DEFAULT
from google.cloud.dataproc_v1 import JobStatus
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
Expand Down Expand Up @@ -472,27 +473,28 @@ def setUp(self) -> None:
properties={"test": "test"},
)

@parameterized.expand([TASK_ID, f"group.{TASK_ID}"])
@mock.patch(DATAPROC_STRING.format("uuid.uuid4"))
def test_init(self, mock_uuid):
def test_init(self, job_name, mock_uuid):
mock_uuid.return_value = "uuid"
properties = {"test": "test"}
job = {
expected_job_id = f"{job_name}_{mock_uuid.return_value}".replace(".", "_")
expected_job = {
"job": {
"labels": {"airflow-version": AIRFLOW_VERSION},
"placement": {"cluster_name": CLUSTER_NAME},
"reference": {"job_id": TASK_ID + "_uuid", "project_id": GCP_PROJECT},
"reference": {"job_id": expected_job_id, "project_id": GCP_PROJECT},
"test": {"properties": properties},
}
}
builder = DataProcJobBuilder(
project_id=GCP_PROJECT,
task_id=TASK_ID,
task_id=job_name,
cluster_name=CLUSTER_NAME,
job_type="test",
properties=properties,
)

assert job == builder.job
assert expected_job == builder.job

def test_add_labels(self):
labels = {"key": "value"}
Expand Down Expand Up @@ -559,14 +561,22 @@ def test_set_python_main(self):
self.builder.set_python_main(main)
assert main == self.builder.job["job"][self.job_type]["main_python_file_uri"]

@parameterized.expand(
[
("simple", "name"),
("name with underscores", "name_with_dash"),
("name with dot", "group.name"),
("name with dot and underscores", "group.name_with_dash"),
]
)
@mock.patch(DATAPROC_STRING.format("uuid.uuid4"))
def test_set_job_name(self, mock_uuid):
def test_set_job_name(self, name, job_name, mock_uuid):
uuid = "test_uuid"
expected_job_name = f"{job_name}_{uuid[:8]}".replace(".", "_")
mock_uuid.return_value = uuid
name = "name"
self.builder.set_job_name(name)
name += "_" + uuid[:8]
assert name == self.builder.job["job"]["reference"]["job_id"]
self.builder.set_job_name(job_name)
assert expected_job_name == self.builder.job["job"]["reference"]["job_id"]
assert len(self.builder.job["job"]["reference"]["job_id"]) == len(job_name) + 9

def test_build(self):
assert self.builder.job == self.builder.build()
35 changes: 27 additions & 8 deletions tests/providers/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,8 +1204,9 @@ class TestDataProcHiveOperator(unittest.TestCase):
query = "define sin HiveUDF('sin');"
variables = {"key": "value"}
job_id = "uuid_id"
job_name = "simple"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"hive_job": {"query_list": {"queries": [query]}, "script_variables": variables},
Expand All @@ -1226,6 +1227,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitHiveJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1249,6 +1251,7 @@ def test_builder(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitHiveJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1263,8 +1266,9 @@ class TestDataProcPigOperator(unittest.TestCase):
query = "define sin HiveUDF('sin');"
variables = {"key": "value"}
job_id = "uuid_id"
job_name = "simple"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"pig_job": {"query_list": {"queries": [query]}, "script_variables": variables},
Expand All @@ -1285,6 +1289,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitPigJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1308,6 +1313,7 @@ def test_builder(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitPigJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1321,15 +1327,16 @@ def test_builder(self, mock_hook, mock_uuid):
class TestDataProcSparkSqlOperator(unittest.TestCase):
query = "SHOW DATABASES;"
variables = {"key": "value"}
job_name = "simple"
job_id = "uuid_id"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
}
other_project_job = {
"reference": {"project_id": "other-project", "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": "other-project", "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"spark_sql_job": {"query_list": {"queries": [query]}, "script_variables": variables},
Expand All @@ -1350,6 +1357,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitSparkSqlJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1375,6 +1383,7 @@ def test_execute_override_project_id(self, mock_hook, mock_uuid):
mock_hook.return_value.submit_job.return_value.reference.job_id = self.job_id

op = DataprocSubmitSparkSqlJobOperator(
job_name=self.job_name,
project_id="other-project",
task_id=TASK_ID,
region=GCP_LOCATION,
Expand All @@ -1399,6 +1408,7 @@ def test_builder(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitSparkSqlJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1412,10 +1422,11 @@ def test_builder(self, mock_hook, mock_uuid):
class TestDataProcSparkOperator(DataprocJobTestBase):
main_class = "org.apache.spark.examples.SparkPi"
jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
job_name = "simple"
job = {
"reference": {
"project_id": GCP_PROJECT,
"job_id": "{{task.task_id}}_{{ds_nodash}}_" + TEST_JOB_ID,
"job_id": f"{job_name}_{TEST_JOB_ID}",
},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
Expand All @@ -1440,6 +1451,7 @@ def test_execute(self, mock_hook, mock_uuid):
self.extra_links_manager_mock.attach_mock(mock_hook, 'hook')

op = DataprocSubmitSparkJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -1505,9 +1517,10 @@ def test_submit_spark_job_operator_extra_links(mock_hook, dag_maker, create_task
class TestDataProcHadoopOperator(unittest.TestCase):
args = ["wordcount", "gs://pub/shakespeare/rose.txt"]
jar = "file:///usr/lib/spark/examples/jars/spark-examples.jar"
job_name = "simple"
job_id = "uuid_id"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"hadoop_job": {"main_jar_file_uri": jar, "args": args},
Expand All @@ -1529,6 +1542,7 @@ def test_execute(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitHadoopJobOperator(
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -1542,8 +1556,9 @@ def test_execute(self, mock_hook, mock_uuid):
class TestDataProcPySparkOperator(unittest.TestCase):
uri = "gs://{}/{}"
job_id = "uuid_id"
job_name = "simple"
job = {
"reference": {"project_id": GCP_PROJECT, "job_id": "{{task.task_id}}_{{ds_nodash}}_" + job_id},
"reference": {"project_id": GCP_PROJECT, "job_id": f"{job_name}_{job_id}"},
"placement": {"cluster_name": "cluster-1"},
"labels": {"airflow-version": AIRFLOW_VERSION},
"pyspark_job": {"main_python_file_uri": uri},
Expand All @@ -1562,7 +1577,11 @@ def test_execute(self, mock_hook, mock_uuid):
mock_uuid.return_value = self.job_id

op = DataprocSubmitPySparkJobOperator(
task_id=TASK_ID, region=GCP_LOCATION, gcp_conn_id=GCP_CONN_ID, main=self.uri
job_name=self.job_name,
task_id=TASK_ID,
region=GCP_LOCATION,
gcp_conn_id=GCP_CONN_ID,
main=self.uri,
)
job = op.generate_job()
assert self.job == job
Expand Down

0 comments on commit fa151fc

Please sign in to comment.