From 9991a5442e3c4fe09b661deb55dd11cbaa1e10fa Mon Sep 17 00:00:00 2001 From: yesemsanthoshkumar <8852302+yesemsanthoshkumar@users.noreply.github.com> Date: Thu, 6 May 2021 01:52:57 +0530 Subject: [PATCH] apache#9941 Fix tests --- .../google/cloud/operators/dataproc.py | 47 ++++-------------- .../google/cloud/operators/test_dataproc.py | 49 +++++-------------- 2 files changed, 23 insertions(+), 73 deletions(-) diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py index d8bcebe0dbe0b..d8df03a9505f0 100644 --- a/airflow/providers/google/cloud/operators/dataproc.py +++ b/airflow/providers/google/cloud/operators/dataproc.py @@ -428,10 +428,7 @@ def _build_cluster_data(self): if self.init_actions_uris: init_actions_dict = [ - { - "executable_file": uri, - "execution_timeout": self._get_init_action_timeout(), - } + {'executable_file': uri, 'execution_timeout': self._get_init_action_timeout()} for uri in self.init_actions_uris ] cluster_data['initialization_actions'] = init_actions_dict @@ -636,11 +633,7 @@ def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None: gcs_uri = hook.diagnose_cluster( region=self.region, cluster_name=self.cluster_name, project_id=self.project_id ) - self.log.info( - "Diagnostic information for cluster %s available at: %s", - self.cluster_name, - gcs_uri, - ) + self.log.info('Diagnostic information for cluster %s available at: %s', self.cluster_name, gcs_uri) if self.delete_on_error: self._delete_cluster(hook) raise AirflowException("Cluster was created but was in ERROR state.") @@ -672,11 +665,8 @@ def _wait_for_cluster_in_creating_state(self, hook: DataprocHook) -> Cluster: return cluster def execute(self, context) -> dict: - self.log.info("Creating cluster: %s", self.cluster_name) - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + self.log.info('Creating cluster: %s', self.cluster_name) + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) # Save data required to display extra link no matter what the cluster status will be self.xcom_push( context, @@ -709,11 +699,6 @@ def execute(self, context) -> dict: cluster = self._create_cluster(hook) self._handle_error_state(hook, cluster) - self.xcom_push( - context, - key='cluster_conf', - value={'cluster_name': self.cluster_name, 'region': self.region, 'project_id': self.project_id}, - ) return Cluster.to_dict(cluster) @@ -844,15 +829,9 @@ def execute(self, context) -> None: self.log.info("Scaling cluster: %s", self.cluster_name) scaling_cluster_data = self._build_scale_cluster_data() - update_mask = [ - "config.worker_config.num_instances", - "config.secondary_worker_config.num_instances", - ] - - hook = DataprocHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) + update_mask = ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"] + + hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) # Save data required to display extra link no matter what the cluster status will be self.xcom_push( context, @@ -1119,9 +1098,7 @@ def on_kill(self) -> None: """ if self.dataproc_job_id: self.hook.cancel_job( - project_id=self.project_id, - job_id=self.dataproc_job_id, - location=self.region, + project_id=self.project_id, job_id=self.dataproc_job_id, location=self.region ) @@ -1635,9 +1612,7 @@ def generate_job(self): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, - region=self.region, - cluster_name=self.cluster_name, + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name ) bucket = cluster_info['config']['config_bucket'] self.main = f"gs://{bucket}/{self.main}" @@ -1654,9 +1629,7 @@ def execute(self, context): # Check if the file is local, if that is the case, upload it to a bucket if os.path.isfile(self.main): cluster_info = self.hook.get_cluster( - project_id=self.hook.project_id, - region=self.region, - cluster_name=self.cluster_name, + project_id=self.hook.project_id, region=self.region, cluster_name=self.cluster_name ) bucket = cluster_info['config']['config_bucket'] self.main = self._upload_file_temp(bucket, self.main) diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py index 4733f10a64f20..9a0ef21b7d327 100644 --- a/tests/providers/google/cloud/operators/test_dataproc.py +++ b/tests/providers/google/cloud/operators/test_dataproc.py @@ -424,16 +424,13 @@ def test_execute(self, mock_hook, to_dict_mock): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args) # Test whether xcom push occurs before create cluster is called self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result()) - mock_hook.return_value.create_cluster.assert_called_once_with(**create_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, @@ -460,10 +457,7 @@ def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.create_cluster.assert_called_once_with( region=GCP_LOCATION, project_id=GCP_PROJECT, @@ -666,15 +660,12 @@ def test_execute(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) # Test whether xcom push occurs before cluster is updated self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, @@ -794,6 +785,8 @@ def test_execute(self, mock_hook): ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + # Test whether xcom push occurs before polling for job self.assertLess( self.extra_links_manager_mock.mock_calls.index(xcom_push_call), @@ -801,10 +794,6 @@ def test_execute(self, mock_hook): msg='Xcom push for Job Link has to be done before polling for job status', ) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, location=GCP_LOCATION, @@ -980,15 +969,12 @@ def test_execute(self, mock_hook): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=self.mock_context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) # Test whether the xcom push happens before updating the cluster self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) - mock_hook.return_value.update_cluster.assert_called_once_with(**update_cluster_args) self.mock_ti.xcom_push.assert_called_once_with( key="cluster_conf", value=DATAPROC_CLUSTER_CONF_EXPECTED, @@ -1148,10 +1134,7 @@ def test_execute(self, mock_hook, mock_uuid): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=MagicMock()) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -1210,10 +1193,7 @@ def test_execute(self, mock_hook, mock_uuid): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=MagicMock()) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -1278,10 +1258,7 @@ def test_execute(self, mock_hook, mock_uuid): impersonation_chain=IMPERSONATION_CHAIN, ) op.execute(context=MagicMock()) - mock_hook.assert_called_once_with( - gcp_conn_id=GCP_CONN_ID, - impersonation_chain=IMPERSONATION_CHAIN, - ) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, job=self.job, location=GCP_LOCATION ) @@ -1306,7 +1283,7 @@ def test_execute_override_project_id(self, mock_hook, mock_uuid): variables=self.variables, impersonation_chain=IMPERSONATION_CHAIN, ) - op.execute(context={}) + op.execute(context=MagicMock()) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.submit_job.assert_called_once_with( project_id="other-project", job=self.other_project_job, location=GCP_LOCATION