diff --git a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py index fc7ea9ba85445..c3e2bc369807c 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/vertex_ai/dataset.py @@ -410,7 +410,7 @@ def execute(self, context: Context): ) initial_dataset_size = self._get_number_of_ds_items( dataset=hook.get_dataset( - dataset_id=self.dataset_id, + dataset=self.dataset_id, project_id=self.project_id, region=self.region, retry=self.retry, @@ -432,7 +432,7 @@ def execute(self, context: Context): hook.wait_for_operation(timeout=self.timeout, operation=operation) result_dataset_size = self._get_number_of_ds_items( dataset=hook.get_dataset( - dataset_id=self.dataset_id, + dataset=self.dataset_id, project_id=self.project_id, region=self.region, retry=self.retry, diff --git a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py index 47a19ed6b104c..95d649ccc9231 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py +++ b/providers/google/tests/unit/google/cloud/operators/test_vertex_ai.py @@ -1387,8 +1387,9 @@ def test_execute(self, mock_hook): FINAL_DS_SIZE = 101 INITIAL_DS = {**SAMPLE_DATASET, "data_item_count": INITIAL_DS_SIZE} FINAL_DS = {**SAMPLE_DATASET, "data_item_count": FINAL_DS_SIZE} + get_ds_mock = mock_hook.return_value.get_dataset - mock_hook.return_value.get_dataset.side_effect = [Dataset(INITIAL_DS), Dataset(FINAL_DS)] + get_ds_mock.side_effect = [Dataset(INITIAL_DS), Dataset(FINAL_DS)] res = op.execute(context={}) @@ -1404,6 +1405,22 @@ def test_execute(self, mock_hook): ) assert res["total_data_items_imported"] == FINAL_DS_SIZE - INITIAL_DS_SIZE + assert get_ds_mock.call_count == 2 + sample_get_ds_kwargs = dict( + region=GCP_LOCATION, + project_id=GCP_PROJECT, + dataset=TEST_DATASET_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + get_ds_mock.assert_has_calls( + [ + call(**sample_get_ds_kwargs), + call(**sample_get_ds_kwargs), + ] + ) + class TestVertexAIListDatasetsOperator: @mock.patch(VERTEX_AI_PATH.format("dataset.Dataset.to_dict"))