[AIRFLOW-4074] Cannot put labels on Cloud Dataproc jobs
Add option to add labels to Dataproc jobs.

fixup! [AIRFLOW-4074] Cannot put labels on Cloud Dataproc jobs
turbaszek committed Jul 20, 2019
1 parent 96933b0 commit 40128c1
Showing 3 changed files with 142 additions and 4 deletions.
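
For context, here is a minimal usage sketch of the new labels argument from a DAG. This is a hedged illustration: the DAG, cluster name, query, and label values are hypothetical and not part of this commit.

    from airflow import DAG
    from airflow.contrib.operators.dataproc_operator import DataProcPigOperator
    from airflow.utils.dates import days_ago

    # Hypothetical DAG; only the `labels` argument exercises this change.
    with DAG('dataproc_labels_example', start_date=days_ago(1)) as dag:
        pig_task = DataProcPigOperator(
            task_id='pig_with_labels',
            region='europe-west1',
            cluster_name='example-cluster',
            query="ls = LOAD '/data' AS (line:chararray); DUMP ls;",
            labels={'team': 'data-eng', 'env': 'prod'},  # merged with the default airflow-version label
        )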
12 changes: 12 additions & 0 deletions airflow/contrib/hooks/gcp_dataproc_hook.py
@@ -27,6 +27,7 @@
from googleapiclient.discovery import build
from zope.deprecation import deprecation

from airflow.version import version
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
from airflow.utils.log.logging_mixin import LoggingMixin

@@ -208,13 +209,24 @@ def __init__(self, project_id, task_id, cluster_name, job_type, properties):
"placement": {
"clusterName": cluster_name
},
"labels": {'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')},
job_type: {
}
}
}
if properties is not None:
self.job["job"][job_type]["properties"] = properties

def add_labels(self, labels):
"""
Set labels for Dataproc job.
:param labels: Labels for the job query.
:type labels: dict
"""
if labels:
self.job["job"]["labels"].update(labels)

def add_variables(self, variables):
"""
Set variables for Dataproc job.
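The add_labels hook method above merges user labels into the default airflow-version label. A standalone sketch of those merge semantics, using plain dicts in place of the hook's job template (illustrative, not part of the commit):

    # dict.update means user keys win on collision; None or {} is a no-op.
    job = {'job': {'labels': {'airflow-version': 'v1-10-4'}}}

    def add_labels(job, labels):
        if labels:
            job['job']['labels'].update(labels)

    add_labels(job, {'team': 'data-eng', 'env': 'prod'})
    print(job['job']['labels'])
    # -> {'airflow-version': 'v1-10-4', 'team': 'data-eng', 'env': 'prod'}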
9 changes: 8 additions & 1 deletion airflow/contrib/operators/dataproc_operator.py
@@ -633,6 +633,10 @@ class DataProcJobBaseOperator(BaseOperator):
For this to work, the service account making the request must have domain-wide
delegation enabled.
:type delegate_to: str
:param labels: The labels to associate with this job. Label keys must contain 1 to 63 characters,
and must conform to RFC 1035. Label values may be empty, but, if present, must contain 1 to 63
characters, and must conform to RFC 1035. No more than 32 labels can be associated with a job.
:type labels: dict
:param region: The specified region where the dataproc cluster is created.
:type region: str
:param job_error_states: Job states that should be considered error states.
@@ -658,13 +662,15 @@ def __init__(self,
dataproc_jars=None,
gcp_conn_id='google_cloud_default',
delegate_to=None,
labels=None,
region='global',
job_error_states=None,
*args,
**kwargs):
super().__init__(*args, **kwargs)
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.labels = labels
self.job_name = job_name
self.cluster_name = cluster_name
self.dataproc_properties = dataproc_properties
@@ -684,8 +690,9 @@ def create_job_template(self):
"""
self.job_template = self.hook.create_job_template(self.task_id, self.cluster_name, self.job_type,
self.dataproc_properties)
self.job_template.set_job_name(self.job_name)
self.job_template.add_jar_file_uris(self.dataproc_jars)
self.job_template.add_labels(self.labels)

def execute(self, context):
if self.job_template:
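The operator docstring above quotes the Dataproc label constraints (RFC 1035 keys and values of 1 to 63 characters, at most 32 labels per job). Below is a hedged sketch of client-side validation against those constraints; this helper is hypothetical and is not part of this commit or the Dataproc client:

    import re

    # RFC 1035 label: starts with a lowercase letter, ends with a letter or
    # digit, and contains only lowercase letters, digits, and hyphens.
    _RFC1035 = re.compile(r'^[a-z]([-a-z0-9]*[a-z0-9])?$')

    def validate_labels(labels):
        if not labels:
            return
        if len(labels) > 32:
            raise ValueError('No more than 32 labels can be associated with a job')
        for key, value in labels.items():
            if not (1 <= len(key) <= 63 and _RFC1035.match(key)):
                raise ValueError('Invalid label key: %r' % key)
            if value and not (len(value) <= 63 and _RFC1035.match(value)):
                raise ValueError('Invalid label value: %r' % value)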
125 changes: 122 additions & 3 deletions tests/contrib/operators/test_dataproc_operator.py
@@ -22,7 +22,6 @@
import re
import unittest
from unittest.mock import MagicMock, Mock, patch

from typing import Dict

import time
@@ -82,6 +81,12 @@
MAIN_URI = 'test-uri'
TEMPLATE_ID = 'template-id'

LABELS = {
'label_a': 'value_a',
'label_b': 'value_b',
'airflow-version': 'v' + version.replace('.', '-').replace('+', '-')
}

HOOK = 'airflow.contrib.operators.dataproc_operator.DataProcHook'
DATAPROC_JOB_ID = 'dataproc_job_id'
DATAPROC_JOB_TO_SUBMIT = {
@@ -92,7 +97,8 @@
},
'placement': {
'clusterName': CLUSTER_NAME
},
'labels': LABELS
}
}

@@ -563,7 +569,13 @@ def submit_side_effect(_1, _2, _3, _4):
with patch(HOOK) as mock_hook:
mock_hook = mock_hook()
mock_hook.submit.side_effect = submit_side_effect
mock_hook.create_job_template().build.return_value = {
'job': {
'reference': {
'jobId': job_id
}
}
}

task = DataProcJobBaseOperator(
task_id=TASK_ID,
@@ -580,6 +592,27 @@ def submit_side_effect(_1, _2, _3, _4):

class DataProcHadoopOperatorTest(unittest.TestCase):
# Unit test for the DataProcHadoopOperator
@mock.patch('airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator.execute')
@mock.patch('airflow.contrib.operators.dataproc_operator.uuid.uuid4', return_value='test')
def test_correct_job_definition(self, mock_uuid, mock_execute):
# Expected job
job_definition = deepcopy(DATAPROC_JOB_TO_SUBMIT)
job_definition['job']['hadoopJob'] = {'mainClass': None}
job_definition['job']['reference']['projectId'] = None
job_definition['job']['reference']['jobId'] = DATAPROC_JOB_ID + "_test"

# Prepare job using operator
task = DataProcHadoopOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
job_name=DATAPROC_JOB_ID,
labels=LABELS
)

task.execute(context=None)
self.assertDictEqual(job_definition, task.job_template.job)

@staticmethod
def test_hook_correct_region():
with patch(HOOK) as mock_hook:
@@ -604,6 +637,27 @@ def test_dataproc_job_id_is_set():

class DataProcHiveOperatorTest(unittest.TestCase):
# Unit test for the DataProcHiveOperator
@mock.patch('airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator.execute')
@mock.patch('airflow.contrib.operators.dataproc_operator.uuid.uuid4', return_value='test')
def test_correct_job_definition(self, mock_uuid, mock_execute):
# Expected job
job_definition = deepcopy(DATAPROC_JOB_TO_SUBMIT)
job_definition['job']['hiveJob'] = {'queryFileUri': None}
job_definition['job']['reference']['projectId'] = None
job_definition['job']['reference']['jobId'] = DATAPROC_JOB_ID + "_test"

# Prepare job using operator
task = DataProcHiveOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
job_name=DATAPROC_JOB_ID,
labels=LABELS
)

task.execute(context=None)
self.assertDictEqual(job_definition, task.job_template.job)

@staticmethod
def test_hook_correct_region():
with patch(HOOK) as mock_hook:
@@ -627,6 +681,27 @@ def test_dataproc_job_id_is_set():


class DataProcPigOperatorTest(unittest.TestCase):
@mock.patch('airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator.execute')
@mock.patch('airflow.contrib.operators.dataproc_operator.uuid.uuid4', return_value='test')
def test_correct_job_definition(self, mock_uuid, mock_execute):
# Expected job
job_definition = deepcopy(DATAPROC_JOB_TO_SUBMIT)
job_definition['job']['pigJob'] = {'queryFileUri': None}
job_definition['job']['reference']['projectId'] = None
job_definition['job']['reference']['jobId'] = DATAPROC_JOB_ID + "_test"

# Prepare job using operator
task = DataProcPigOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
job_name=DATAPROC_JOB_ID,
labels=LABELS
)

task.execute(context=None)
self.assertDictEqual(job_definition, task.job_template.job)

@staticmethod
def test_hook_correct_region():
with patch(HOOK) as mock_hook:
@@ -655,6 +730,28 @@ def test_dataproc_job_id_is_set():

class DataProcPySparkOperatorTest(unittest.TestCase):
# Unit test for the DataProcPySparkOperator
@mock.patch('airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator.execute')
@mock.patch('airflow.contrib.operators.dataproc_operator.uuid.uuid4', return_value='test')
def test_correct_job_definition(self, mock_uuid, mock_execute):
# Expected job
job_definition = deepcopy(DATAPROC_JOB_TO_SUBMIT)
job_definition['job']['pysparkJob'] = {'mainPythonFileUri': 'main_class'}
job_definition['job']['reference']['projectId'] = None
job_definition['job']['reference']['jobId'] = DATAPROC_JOB_ID + "_test"

# Prepare job using operator
task = DataProcPySparkOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
job_name=DATAPROC_JOB_ID,
labels=LABELS,
main="main_class"
)

task.execute(context=None)
self.assertDictEqual(job_definition, task.job_template.job)

@staticmethod
def test_hook_correct_region():
with patch(HOOK) as mock_hook:
@@ -681,6 +778,28 @@ def test_dataproc_job_id_is_set():

class DataProcSparkOperatorTest(unittest.TestCase):
# Unit test for the DataProcSparkOperator
@mock.patch('airflow.contrib.operators.dataproc_operator.DataProcJobBaseOperator.execute')
@mock.patch('airflow.contrib.operators.dataproc_operator.uuid.uuid4', return_value='test')
def test_correct_job_definition(self, mock_hook, mock_uuid):
# Expected job
job_definition = deepcopy(DATAPROC_JOB_TO_SUBMIT)
job_definition['job']['sparkJob'] = {'mainClass': 'main_class'}
job_definition['job']['reference']['projectId'] = None
job_definition['job']['reference']['jobId'] = DATAPROC_JOB_ID + "_test"

# Prepare job using operator
task = DataProcSparkOperator(
task_id=TASK_ID,
region=GCP_REGION,
cluster_name=CLUSTER_NAME,
job_name=DATAPROC_JOB_ID,
labels=LABELS,
main_class="main_class"
)

task.execute(context=None)
self.assertDictEqual(job_definition, task.job_template.job)

@staticmethod
def test_hook_correct_region():
with patch(HOOK) as mock_hook:
