Skip to content

Commit

Permalink
apache#9941 Fix tests
Browse files: browse the repository at this point in the history
  • Loading branch information
yesemsanthoshkumar committed May 5, 2021
1 parent 33820d0 commit 9991a54
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 73 deletions.
47 changes: 10 additions & 37 deletions airflow/providers/google/cloud/operators/dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,7 @@ def _build_cluster_data(self):

if self.init_actions_uris:
init_actions_dict = [
{
"executable_file": uri,
"execution_timeout": self._get_init_action_timeout(),
}
{'executable_file': uri, 'execution_timeout': self._get_init_action_timeout()}
for uri in self.init_actions_uris
]
cluster_data['initialization_actions'] = init_actions_dict
Expand Down Expand Up @@ -636,11 +633,7 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
gcs_uri = hook.diagnose_cluster(
region=self.region, cluster_name=self.cluster_name, project_id=self.project_id
)
self.log.info(
"Diagnostic information for cluster %s available at: %s",
self.cluster_name,
gcs_uri,
)
self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri)
if self.delete_on_error:
self._delete_cluster(hook)
raise AirflowException("Cluster was created but was in ERROR state.")
Expand Down Expand Up @@ -672,11 +665,8 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster:
return cluster

def execute(self, context) -> dict:
self.log.info("Creating cluster: %s", self.cluster_name)
hook = DataprocHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
self.log.info('Creating cluster: %s', self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required to display extra link no matter what the cluster status will be
self.xcom_push(
context,
Expand Down Expand Up @@ -709,11 +699,6 @@ def execute(self, context) -> dict:
cluster = self._create_cluster(hook)
self._handle_error_state(hook, cluster)

self.xcom_push(
context,
key='cluster_conf',
value={'cluster_name': self.cluster_name, 'region': self.region, 'project_id': self.project_id},
)
return Cluster.to_dict(cluster)


Expand Down Expand Up @@ -844,15 +829,9 @@ def execute(self, context) -> None:
self.log.info("Scaling cluster: %s", self.cluster_name)

scaling_cluster_data = self._build_scale_cluster_data()
update_mask = [
"config.worker_config.num_instances",
"config.secondary_worker_config.num_instances",
]

hook = DataprocHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"]

hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
# Save data required to display extra link no matter what the cluster status will be
self.xcom_push(
context,
Expand Down Expand Up @@ -1119,9 +1098,7 @@ def on_kill(self) -> None:
"""
if self.dataproc_job_id:
self.hook.cancel_job(
project_id=self.project_id,
job_id=self.dataproc_job_id,
location=self.region,
project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region
)


Expand Down Expand Up @@ -1635,9 +1612,7 @@ def generate_job(self):
# Check if the file is local, if that is the case, upload it to a bucket
if os.path.isfile(self.main):
cluster_info = self.hook.get_cluster(
project_id=self.hook.project_id,
region=self.region,
cluster_name=self.cluster_name,
project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name
)
bucket = cluster_info['config']['config_bucket']
self.main = f"gs://{bucket}/{self.main}"
Expand All @@ -1654,9 +1629,7 @@ def execute(self, context):
# Check if the file is local, if that is the case, upload it to a bucket
if os.path.isfile(self.main):
cluster_info = self.hook.get_cluster(
project_id=self.hook.project_id,
region=self.region,
cluster_name=self.cluster_name,
project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name
)
bucket = cluster_info['config']['config_bucket']
self.main = self._upload_file_temp(bucket, self.main)
Expand Down
49 changes: 13 additions & 36 deletions tests/providers/google/cloud/operators/test_dataproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,16 +424,13 @@ def test_execute(self, mock_hook, to_dict_mock):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args)

# Test whether xcom push occurs before create cluster is called
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())
mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args)
self.mock_ti.xcom_push.assert_called_once_with(
key="cluster_conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
Expand All @@ -460,10 +457,7 @@ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.create_cluster.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
Expand Down Expand Up @@ -666,15 +660,12 @@ def test_execute(self, mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)

# Test whether xcom push occurs before cluster is updated
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)
self.mock_ti.xcom_push.assert_called_once_with(
key="cluster_conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
Expand Down Expand Up @@ -794,17 +785,15 @@ def test_execute(self, mock_hook):
)
op.execute(context=self.mock_context)

mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)

# Test whether xcom push occurs before polling for job
self.assertLess(
self.extra_links_manager_mock.mock_calls.index(xcom_push_call),
self.extra_links_manager_mock.mock_calls.index(wait_for_job_call),
msg='Xcom push for Job Link has to be done before polling for job status',
)

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
Expand Down Expand Up @@ -980,15 +969,12 @@ def test_execute(self, mock_hook):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=self.mock_context)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)

# Test whether the xcom push happens before updating the cluster
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)

mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args)
self.mock_ti.xcom_push.assert_called_once_with(
key="cluster_conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED,
Expand Down Expand Up @@ -1148,10 +1134,7 @@ def test_execute(self, mock_hook, mock_uuid):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
)
Expand Down Expand Up @@ -1210,10 +1193,7 @@ def test_execute(self, mock_hook, mock_uuid):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
)
Expand Down Expand Up @@ -1278,10 +1258,7 @@ def test_execute(self, mock_hook, mock_uuid):
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION
)
Expand All @@ -1306,7 +1283,7 @@ def test_execute_override_project_id(self, mock_hook, mock_uuid):
variables=self.variables,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={})
op.execute(context=MagicMock())
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.submit_job.assert_called_once_with(
project_id="other-project", job=self.other_project_job, location=GCP_LOCATION
Expand Down

0 comments on commit 9991a54

Please sign in to comment.