diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py
index 8da25c042c036..34abb250b9f92 100644
--- a/airflow-core/tests/unit/always/test_project_structure.py
+++ b/airflow-core/tests/unit/always/test_project_structure.py
@@ -136,7 +136,6 @@ def test_providers_modules_should_have_tests(self):
"providers/fab/tests/unit/fab/www/test_session.py",
"providers/fab/tests/unit/fab/www/test_views.py",
"providers/google/tests/unit/google/cloud/fs/test_gcs.py",
- "providers/google/tests/unit/google/cloud/links/test_automl.py",
"providers/google/tests/unit/google/cloud/links/test_base.py",
"providers/google/tests/unit/google/cloud/links/test_bigquery.py",
"providers/google/tests/unit/google/cloud/links/test_bigquery_dts.py",
@@ -162,6 +161,7 @@ def test_providers_modules_should_have_tests(self):
"providers/google/tests/unit/google/cloud/links/test_spanner.py",
"providers/google/tests/unit/google/cloud/links/test_stackdriver.py",
"providers/google/tests/unit/google/cloud/links/test_workflows.py",
+ "providers/google/tests/unit/google/cloud/links/test_translate.py",
"providers/google/tests/unit/google/cloud/operators/vertex_ai/test_auto_ml.py",
"providers/google/tests/unit/google/cloud/operators/vertex_ai/test_batch_prediction_job.py",
"providers/google/tests/unit/google/cloud/operators/vertex_ai/test_custom_job.py",
diff --git a/providers/google/docs/changelog.rst b/providers/google/docs/changelog.rst
index a61e04d64c271..d76bccc144645 100644
--- a/providers/google/docs/changelog.rst
+++ b/providers/google/docs/changelog.rst
@@ -27,6 +27,30 @@
Changelog
---------
+.. warning::
+ Deprecated classes, parameters and features have been removed from the Google provider package.
+ The following breaking changes were introduced:
+
+* Operators
+
+ * ``Remove AutoMLTrainModelOperator, use airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator, airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator, airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator, airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator, or airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator instead``
+ * ``Remove AutoMLPredictOperator, use airflow.providers.google.cloud.operators.translate.TranslateTextOperator instead``
+ * ``Remove AutoMLCreateDatasetOperator, use airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator or airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator instead``
+ * ``Remove AutoMLImportDataOperator, use airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator or airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator instead``
+ * ``Remove AutoMLTablesListColumnSpecsOperator because of the shutdown of the legacy version of AutoML Tables``
+ * ``Remove AutoMLTablesUpdateDatasetOperator, use airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator instead``
+ * ``Remove AutoMLGetModelOperator, use airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator instead``
+ * ``Remove AutoMLDeleteModelOperator, use airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator or airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator instead``
+ * ``Remove AutoMLDeployModelOperator, use airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator instead``
+ * ``Remove AutoMLTablesListTableSpecsOperator because of the shutdown of the legacy version of AutoML Tables``
+ * ``Remove AutoMLListDatasetOperator, use airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator or airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator instead``
+ * ``Remove AutoMLDeleteDatasetOperator, use airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator or airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator instead``
+ * ``Remove MLEngineCreateModelOperator, use the appropriate Vertex AI operator instead``
+
+* Hooks
+
+ * ``Remove CloudAutoMLHook, use airflow.providers.google.cloud.hooks.vertex_ai.auto_ml.AutoMLHook or airflow.providers.google.cloud.hooks.translate.TranslateHook instead``
+
18.1.0
......
diff --git a/providers/google/docs/integration-logos/Cloud-AutoML.png b/providers/google/docs/integration-logos/Cloud-AutoML.png
deleted file mode 100644
index b147e074b58fd..0000000000000
Binary files a/providers/google/docs/integration-logos/Cloud-AutoML.png and /dev/null differ
diff --git a/providers/google/docs/operators/cloud/automl.rst b/providers/google/docs/operators/cloud/automl.rst
deleted file mode 100644
index bd243e1165daf..0000000000000
--- a/providers/google/docs/operators/cloud/automl.rst
+++ /dev/null
@@ -1,244 +0,0 @@
- .. Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- .. http://www.apache.org/licenses/LICENSE-2.0
-
- .. Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
-
-Google Cloud AutoML Operators
-=============================
-
-.. warning::
- The AutoML API is deprecated. The planned removal date is September 30, 2025, but some operators may be removed
- earlier, as noted in the docs and deprecation warnings.
- Replacement suggestions can be found in the deprecation warnings and in the doc below.
- Please note that the AutoML for Translation API functionality has been moved to the Advanced Translation service;
- the replacement operators can be found in the ``airflow.providers.google.cloud.operators.translate`` module.
-
-The `Google Cloud AutoML <https://cloud.google.com/automl/>`__
-makes the power of machine learning available to you even if you have limited knowledge
-of machine learning. You can use AutoML to build on Google's machine learning capabilities
-to create your own custom machine learning models that are tailored to your business needs,
-and then integrate those models into your applications and web sites.
-
-Prerequisite Tasks
-^^^^^^^^^^^^^^^^^^
-
-.. include:: /operators/_partials/prerequisite_tasks.rst
-
-.. _howto/operator:CloudAutoMLDocuments:
-.. _howto/operator:AutoMLCreateDatasetOperator:
-.. _howto/operator:AutoMLImportDataOperator:
-.. _howto/operator:AutoMLTablesUpdateDatasetOperator:
-
-Creating Datasets
-^^^^^^^^^^^^^^^^^
-
-To create a Google AutoML dataset you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`.
-The operator returns the dataset id in :ref:`XCom <concepts:xcom>` under the ``dataset_id`` key.
-
-This operator is deprecated for text, video and vision prediction and will be removed after September 30, 2025.
-All the functionality of legacy AutoML Natural Language, Vision, Video Intelligence and new features are
-available on the Vertex AI platform. Please use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator` or
-:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator` instead.
-
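-For reference, here is a minimal sketch of the recommended Vertex AI replacement (the dataset body,
-project and region values are illustrative assumptions, not values prescribed by this guide):
-
-.. code-block:: python
-
-    from google.cloud.aiplatform import schema
-
-    from airflow.providers.google.cloud.operators.vertex_ai.dataset import CreateDatasetOperator
-
-    # Illustrative dataset body; metadata_schema_uri selects the dataset type.
-    DATASET = {
-        "display_name": "my-image-dataset",
-        "metadata_schema_uri": schema.dataset.metadata.image,
-    }
-
-    create_dataset = CreateDatasetOperator(
-        task_id="create_dataset",
-        dataset=DATASET,
-        region="us-central1",
-        project_id="my-project",
-    )
-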
-After creating a dataset you can import data into it using
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`.
-
-To update a dataset you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator`.
-
-.. warning::
- This operator is deprecated for text, video and vision prediction and will be removed soon.
- All the functionality of legacy AutoML Natural Language, Vision, Video Intelligence and new features are
- available on the Vertex AI platform. Please use
- :class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator` instead.
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_dataset.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_update_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_update_dataset_operator]
-
-.. _howto/operator:AutoMLTablesListTableSpecsOperator:
-.. _howto/operator:AutoMLTablesListColumnSpecsOperator:
-
-Listing Table And Columns Specs
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-To list table specs you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator`.
-
-To list column specs you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator`.
-
-AutoML Tables-related operators are deprecated. Please use the related Vertex AI Tabular operators instead.
-
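-As a sketch, a Vertex AI tabular training job could look like the following (the dataset id,
-target column and other values are illustrative assumptions):
-
-.. code-block:: python
-
-    from airflow.providers.google.cloud.operators.vertex_ai.auto_ml import (
-        CreateAutoMLTabularTrainingJobOperator,
-    )
-
-    create_tabular_training_job = CreateAutoMLTabularTrainingJobOperator(
-        task_id="create_tabular_training_job",
-        display_name="my-tabular-training-job",
-        optimization_prediction_type="classification",
-        dataset_id="1234567890",
-        target_column="label",
-        region="us-central1",
-        project_id="my-project",
-    )
-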
-.. _howto/operator:AutoMLTrainModelOperator:
-.. _howto/operator:AutoMLGetModelOperator:
-.. _howto/operator:AutoMLDeployModelOperator:
-.. _howto/operator:AutoMLDeleteModelOperator:
-
-Operations On Models
-^^^^^^^^^^^^^^^^^^^^
-
-To create a Google AutoML model you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator`.
-The operator will wait for the operation to complete. Additionally, the operator
-returns the id of the model in :ref:`XCom <concepts:xcom>` under the ``model_id`` key.
-
-.. warning::
- This operator is deprecated for text, video and vision prediction and will be removed after September 30, 2025.
- All the functionality of legacy AutoML Natural Language, Vision, Video Intelligence and new features are
- available on the Vertex AI platform. Please use
- :class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`,
- :class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`,
- :class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`,
- :class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator`,
- :class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator`.
-
-When running Vertex AI operators for training, please ensure that your data is correctly stored in Vertex AI
-datasets. To create a dataset and import data into it, please use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`
-and
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator`.
-
-For AutoML translation, please use the
-:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextOperator`
-or
-:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator`.
-
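-For instance, a simple translation task could be sketched as follows (the language codes,
-location and project values are illustrative assumptions):
-
-.. code-block:: python
-
-    from airflow.providers.google.cloud.operators.translate import TranslateTextOperator
-
-    translate_text = TranslateTextOperator(
-        task_id="translate_text",
-        contents=["The example content to translate."],
-        source_language_code="en",
-        target_language_code="es",
-        location="us-central1",
-        project_id="my-project",
-    )
-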
-To get an existing model you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`.
-
-This operator is deprecated for tables, video intelligence, vision and natural language
-and will be removed after March 31, 2024. Please use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator` instead.
-You can find an example of how to use Vertex AI operators here:
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_model_service.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_get_model_operator]
- :end-before: [END how_to_cloud_vertex_ai_get_model_operator]
-
-Once a model is created it can be deployed using
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator`.
-
-This operator is deprecated for tables, video intelligence, vision and natural language
-and will be removed after March 31, 2024. Please use
-:class:`airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator` instead.
-You can find an example of how to use Vertex AI operators here:
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_endpoint.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_deploy_model_operator]
- :end-before: [END how_to_cloud_vertex_ai_deploy_model_operator]
-
-If you wish to delete a model you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`.
-
-This operator is deprecated for tables, video intelligence, vision and natural language
-and will be removed after March 31, 2024. Please use
-:class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator` instead.
-You can find an example of how to use Vertex AI operators here:
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_model_service.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_delete_model_operator]
- :end-before: [END how_to_cloud_vertex_ai_delete_model_operator]
-
-.. _howto/operator:AutoMLPredictOperator:
-
-Making Predictions
-^^^^^^^^^^^^^^^^^^
-
-To obtain predictions from a Google Cloud AutoML model you can use
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator`. Note that the model
-must be deployed before it can serve predictions.
-
-
-For tables, video intelligence, vision and natural language you can use the following operators:
-
-:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.CreateBatchPredictionJobOperator`,
-:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator`,
-:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.ListBatchPredictionJobsOperator`,
-:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.DeleteBatchPredictionJobOperator`.
-You can find examples of how to use Vertex AI operators here:
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_create_batch_prediction_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_create_batch_prediction_job_operator]
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_list_batch_prediction_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_list_batch_prediction_job_operator]
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_delete_batch_prediction_job_operator]
- :end-before: [END how_to_cloud_vertex_ai_delete_batch_prediction_job_operator]
-
-.. _howto/operator:AutoMLListDatasetOperator:
-.. _howto/operator:AutoMLDeleteDatasetOperator:
-
-Listing And Deleting Datasets
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-You can get a list of AutoML datasets using
-:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns a list
-of dataset ids in :ref:`XCom <concepts:xcom>` under the ``dataset_id_list`` key.
-
-This operator is deprecated for tables, video intelligence, vision and natural language
-and will be removed after March 31, 2024. Please use
-:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator` or
-:class:`~airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator`
-instead.
-You can find an example of how to use Vertex AI operators here:
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_dataset.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_list_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_list_dataset_operator]
-
-To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`.
-The delete operator also accepts a list or a comma-separated string of dataset ids to be deleted.
-
-This operator is deprecated for tables, video intelligence, vision and natural language
-and will be removed after March 31, 2024. Please use
-:class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator` instead.
-You can find an example of how to use Vertex AI operators here:
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_dataset.py
- :language: python
- :dedent: 4
- :start-after: [START how_to_cloud_vertex_ai_delete_dataset_operator]
- :end-before: [END how_to_cloud_vertex_ai_delete_dataset_operator]
-
-Reference
-^^^^^^^^^
-
-For further information, look at:
-
-* `Client Library Documentation <https://googleapis.dev/python/automl/latest/index.html>`__
-* `Product Documentation <https://cloud.google.com/automl/docs/>`__
diff --git a/providers/google/docs/operators/cloud/mlengine.rst b/providers/google/docs/operators/cloud/mlengine.rst
index 1fc63b9528d76..407954011efe2 100644
--- a/providers/google/docs/operators/cloud/mlengine.rst
+++ b/providers/google/docs/operators/cloud/mlengine.rst
@@ -45,22 +45,8 @@ the Vertex AI platform.
Creating a model
^^^^^^^^^^^^^^^^
-A model is a container that can hold multiple model versions. A new model can be created through the
-:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineCreateModelOperator`.
-The ``model`` field should be defined with a dictionary containing the information about the model.
-``name`` is a required field in this dictionary.
-
-.. warning::
- This operator is deprecated. The model is created as a result of running Vertex AI operators that create training jobs
- of any type. For example, you can use
- :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`.
- The result of running this operator will be a ready-to-use model saved in the Model Registry.
-
-.. exampleinclude:: /../../google/tests/system/google/cloud/ml_engine/example_mlengine.py
- :language: python
- :dedent: 4
- :start-after: [START howto_operator_create_custom_python_training_job_v1]
- :end-before: [END howto_operator_create_custom_python_training_job_v1]
+This function is deprecated. All the functionality of legacy MLEngine and new features are available on
+the Vertex AI platform.
Getting a model
^^^^^^^^^^^^^^^
@@ -72,18 +58,13 @@ Creating model versions
This function is deprecated. All the functionality of legacy MLEngine and new features are available on
the Vertex AI platform.
-For model versioning please check:
-`Model versioning with Vertex AI
-<https://cloud.google.com/vertex-ai/docs/model-registry/versioning>`__
Managing model versions
^^^^^^^^^^^^^^^^^^^^^^^
This function is deprecated. All the functionality of legacy MLEngine and new features are available on
the Vertex AI platform.
-For model versioning please check:
-`Model versioning with Vertex AI
-<https://cloud.google.com/vertex-ai/docs/model-registry/versioning>`__
+
Making predictions
^^^^^^^^^^^^^^^^^^
@@ -103,14 +84,3 @@ Evaluating a model
This function is deprecated. All the functionality of legacy MLEngine and new features are available on
the Vertex AI platform.
-To create and view Model Evaluation, please check the documentation:
-`Evaluate models using Vertex AI
-<https://cloud.google.com/vertex-ai/docs/evaluation/introduction>`__
-
-Reference
-^^^^^^^^^
-
-For further information, look at:
-
-* `Client Library Documentation `__
-* `Product Documentation `__
diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml
index 09f5ee4ff87af..5c98aacfa6529 100644
--- a/providers/google/provider.yaml
+++ b/providers/google/provider.yaml
@@ -127,12 +127,6 @@ integrations:
how-to-guide:
- /docs/apache-airflow-providers-google/operators/ads.rst
tags: [gmp]
- - integration-name: Google AutoML
- external-doc-url: https://cloud.google.com/automl/
- how-to-guide:
- - /docs/apache-airflow-providers-google/operators/cloud/automl.rst
- logo: /docs/integration-logos/Cloud-AutoML.png
- tags: [gcp]
- integration-name: Google BigQuery Data Transfer Service
external-doc-url: https://cloud.google.com/bigquery/transfer/
logo: /docs/integration-logos/BigQuery.png
@@ -473,9 +467,6 @@ operators:
- integration-name: Google Cloud AlloyDB
python-modules:
- airflow.providers.google.cloud.operators.alloy_db
- - integration-name: Google AutoML
- python-modules:
- - airflow.providers.google.cloud.operators.automl
- integration-name: Google BigQuery
python-modules:
- airflow.providers.google.cloud.operators.bigquery
@@ -539,9 +530,6 @@ operators:
- integration-name: Google Kubernetes Engine
python-modules:
- airflow.providers.google.cloud.operators.kubernetes_engine
- - integration-name: Google Machine Learning Engine
- python-modules:
- - airflow.providers.google.cloud.operators.mlengine
- integration-name: Google Cloud Natural Language
python-modules:
- airflow.providers.google.cloud.operators.natural_language
@@ -729,9 +717,6 @@ hooks:
- integration-name: Google Ads
python-modules:
- airflow.providers.google.ads.hooks.ads
- - integration-name: Google AutoML
- python-modules:
- - airflow.providers.google.cloud.hooks.automl
- integration-name: Google BigQuery
python-modules:
- airflow.providers.google.cloud.hooks.bigquery
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/automl.py b/providers/google/src/airflow/providers/google/cloud/hooks/automl.py
deleted file mode 100644
index a60262c21b900..0000000000000
--- a/providers/google/src/airflow/providers/google/cloud/hooks/automl.py
+++ /dev/null
@@ -1,673 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module contains a Google AutoML hook.
-
-.. spelling:word-list::
-
- PredictResponse
-"""
-
-from __future__ import annotations
-
-from collections.abc import Sequence
-from functools import cached_property
-from typing import TYPE_CHECKING
-
-from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.automl_v1beta1 import (
- AutoMlClient,
- BatchPredictInputConfig,
- BatchPredictOutputConfig,
- Dataset,
- ExamplePayload,
- ImageObjectDetectionModelDeploymentMetadata,
- InputConfig,
- Model,
- PredictionServiceClient,
- PredictResponse,
-)
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.deprecated import deprecated
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
-from airflow.providers.google.common.hooks.operation_helpers import OperationHelper
-
-if TYPE_CHECKING:
- from google.api_core.operation import Operation
- from google.api_core.retry import Retry
- from google.cloud.automl_v1beta1.services.auto_ml.pagers import (
- ListColumnSpecsPager,
- ListDatasetsPager,
- ListTableSpecsPager,
- )
- from google.protobuf.field_mask_pb2 import FieldMask
-
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.hooks.vertex_ai.auto_ml.AutoMLHook, "
- "airflow.providers.google.cloud.hooks.translate.TranslateHook",
- category=AirflowProviderDeprecationWarning,
-)
-class CloudAutoMLHook(GoogleBaseHook, OperationHelper):
- """
- Google Cloud AutoML hook.
-
- All the methods in the hook where project_id is used must be called with
- keyword arguments rather than positional.
- """
-
- def __init__(
- self,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(
- gcp_conn_id=gcp_conn_id,
- impersonation_chain=impersonation_chain,
- **kwargs,
- )
- self._client: AutoMlClient | None = None
-
- @staticmethod
- def extract_object_id(obj: dict) -> str:
- """Return unique id of the object."""
- return obj["name"].rpartition("/")[-1]
-
- def get_conn(self) -> AutoMlClient:
- """
- Retrieve connection to AutoML.
-
- :return: Google Cloud AutoML client object.
- """
- if self._client is None:
- self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
- return self._client
-
- @cached_property
- def prediction_client(self) -> PredictionServiceClient:
- """
- Creates PredictionServiceClient.
-
- :return: Google Cloud AutoML PredictionServiceClient client object.
- """
- return PredictionServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_model(
- self,
- model: dict | Model,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- retry: Retry | _MethodDefault = DEFAULT,
- ) -> Operation:
- """
- Create a model and return a Model in the `response` field when the operation completes.
-
- When you create a model, several model evaluations are created for it:
- a global evaluation, and one evaluation for each annotation spec.
-
- :param model: The model to create. If a dict is provided, it must be of the same form
- as the protobuf message `google.cloud.automl_v1beta1.types.Model`
- :param project_id: ID of the Google Cloud project where model will be created if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests
- will not be retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete.
- Note that if `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
- """
- client = self.get_conn()
- parent = f"projects/{project_id}/locations/{location}"
- return client.create_model(
- request={"parent": parent, "model": model},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
-
- @GoogleBaseHook.fallback_to_default_project_id
- def batch_predict(
- self,
- model_id: str,
- input_config: dict | BatchPredictInputConfig,
- output_config: dict | BatchPredictOutputConfig,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- params: dict[str, str] | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Operation:
- """
- Perform a batch prediction and return a long-running operation object.
-
- Unlike the online `Predict`, the batch prediction result won't be immediately
- available in the response. Instead, a long-running operation object is returned.
-
- :param model_id: Name of the model_id requested to serve the batch prediction.
- :param input_config: Required. The input configuration for batch prediction.
- If a dict is provided, it must be of the same form as the protobuf message
- `google.cloud.automl_v1beta1.types.BatchPredictInputConfig`
- :param output_config: Required. The Configuration specifying where output predictions should be
- written. If a dict is provided, it must be of the same form as the protobuf message
- `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig`
- :param params: Additional domain-specific parameters for the predictions, any string must be up to
- 25000 characters long.
- :param project_id: ID of the Google Cloud project where model is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
- """
- client = self.prediction_client
- name = f"projects/{project_id}/locations/{location}/models/{model_id}"
- result = client.batch_predict(
- request={
- "name": name,
- "input_config": input_config,
- "output_config": output_config,
- "params": params,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def predict(
- self,
- model_id: str,
- payload: dict | ExamplePayload,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- params: dict[str, str] | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> PredictResponse:
- """
- Perform an online prediction and return the prediction result in the response.
-
- :param model_id: Name of the model_id requested to serve the prediction.
- :param payload: Required. Payload to perform a prediction on. The payload must match the problem type
- that the model_id was trained to solve. If a dict is provided, it must be of
- the same form as the protobuf message `google.cloud.automl_v1beta1.types.ExamplePayload`
- :param params: Additional domain-specific parameters, any string must be up to 25000 characters long.
- :param project_id: ID of the Google Cloud project where model is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance
- """
- client = self.prediction_client
- name = f"projects/{project_id}/locations/{location}/models/{model_id}"
- result = client.predict(
- request={"name": name, "payload": payload, "params": params},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
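-
- # Illustrative usage (assumes an existing deployed model; values are placeholders):
- #   response = hook.predict(
- #       model_id="TRL123",
- #       payload={"text_snippet": {"content": "Hello", "mime_type": "text/plain"}},
- #       location="us-central1",
- #       project_id="my-project",
- #   )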
-
- @GoogleBaseHook.fallback_to_default_project_id
- def create_dataset(
- self,
- dataset: dict | Dataset,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Dataset:
- """
- Create a dataset.
-
- :param dataset: The dataset to create. If a dict is provided, it must be of the
- same form as the protobuf message Dataset.
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
- """
- client = self.get_conn()
- parent = f"projects/{project_id}/locations/{location}"
- result = client.create_dataset(
- request={"parent": parent, "dataset": dataset},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def import_data(
- self,
- dataset_id: str,
- location: str,
- input_config: dict | InputConfig,
- project_id: str = PROVIDE_PROJECT_ID,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Operation:
- """
- Import data into a dataset. For Tables this method can only be called on an empty Dataset.
-
- :param dataset_id: Name of the AutoML dataset.
- :param input_config: The desired input location and its domain specific semantics, if any.
- If a dict is provided, it must be of the same form as the protobuf message InputConfig.
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
- """
- client = self.get_conn()
- name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
- result = client.import_data(
- request={"name": name, "input_config": input_config},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_column_specs(
- self,
- dataset_id: str,
- table_spec_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- field_mask: dict | FieldMask | None = None,
- filter_: str | None = None,
- page_size: int | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> ListColumnSpecsPager:
- """
- List column specs in a table spec.
-
- :param dataset_id: Name of the AutoML dataset.
- :param table_spec_id: table_spec_id for path builder.
- :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same
- form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask`
- :param filter_: Filter expression, see go/filtering.
- :param page_size: The maximum number of resources contained in the
- underlying API response. If page streaming is performed per
- resource, this parameter does not affect the return value. If page
- streaming is performed per-page, this determines the maximum number
- of resources in a page.
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types.ColumnSpec` instance.
- """
- client = self.get_conn()
- parent = client.table_spec_path(
- project=project_id,
- location=location,
- dataset=dataset_id,
- table_spec=table_spec_id,
- )
- result = client.list_column_specs(
- request={"parent": parent, "field_mask": field_mask, "filter": filter_, "page_size": page_size},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_model(
- self,
- model_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Model:
- """
- Get an AutoML model.
-
- :param model_id: Name of the model.
- :param project_id: ID of the Google Cloud project where model is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types.Model` instance.
- """
- client = self.get_conn()
- name = f"projects/{project_id}/locations/{location}/models/{model_id}"
- result = client.get_model(
- request={"name": name},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def delete_model(
- self,
- model_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Operation:
- """
- Delete an AutoML model.
-
- :param model_id: Name of the model.
- :param project_id: ID of the Google Cloud project where model is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
- """
- client = self.get_conn()
- name = f"projects/{project_id}/locations/{location}/models/{model_id}"
- result = client.delete_model(
- request={"name": name},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- def update_dataset(
- self,
- dataset: dict | Dataset,
- update_mask: dict | FieldMask | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Dataset:
- """
- Update a dataset.
-
- :param dataset: The dataset which replaces the resource on the server.
- If a dict is provided, it must be of the same form as the protobuf message Dataset.
- :param update_mask: The update mask applies to the resource. If a dict is provided, it must
- be of the same form as the protobuf message FieldMask.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types.Dataset` instance.
- """
- client = self.get_conn()
- result = client.update_dataset(
- request={"dataset": dataset, "update_mask": update_mask},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def deploy_model(
- self,
- model_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- image_detection_metadata: ImageObjectDetectionModelDeploymentMetadata | dict | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Operation:
- """
- Deploys a model.
-
- If a model is already deployed, deploying it with the same parameters
- has no effect. Deploying with different parameters (e.g. changing node_number) will
- reset the deployment state without pausing the model's availability.
-
- Only applicable for Text Classification, Image Object Detection and Tables; all other
- domains manage deployment automatically.
-
- :param model_id: Name of the model requested to serve the prediction.
- :param image_detection_metadata: Model deployment metadata specific to Image Object Detection.
- If a dict is provided, it must be of the same form as the protobuf message
- ImageObjectDetectionModelDeploymentMetadata
- :param project_id: ID of the Google Cloud project where model will be created if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance.
- """
- client = self.get_conn()
- name = f"projects/{project_id}/locations/{location}/models/{model_id}"
- result = client.deploy_model(
- request={
- "name": name,
- "image_object_detection_model_deployment_metadata": image_detection_metadata,
- },
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- def list_table_specs(
- self,
- dataset_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- filter_: str | None = None,
- page_size: int | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> ListTableSpecsPager:
- """
- List table specs in a dataset.
-
- :param dataset_id: Name of the dataset.
- :param filter_: Filter expression, see go/filtering.
- :param page_size: The maximum number of resources contained in the
- underlying API response. If page streaming is performed per
- resource, this parameter does not affect the return value. If page
- streaming is performed per-page, this determines the maximum number
- of resources in a page.
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: A `google.gax.PageIterator` instance. By default, this
- is an iterable of `google.cloud.automl_v1beta1.types.TableSpec` instances.
- This object can also be configured to iterate over the pages
- of the response through the `options` parameter.
- """
- client = self.get_conn()
- parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
- result = client.list_table_specs(
- request={"parent": parent, "filter": filter_, "page_size": page_size},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def list_datasets(
- self,
- location: str,
- project_id: str,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> ListDatasetsPager:
- """
- List datasets in a project.
-
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: A `google.gax.PageIterator` instance. By default, this
- is an iterable of `google.cloud.automl_v1beta1.types.Dataset` instances.
- This object can also be configured to iterate over the pages
- of the response through the `options` parameter.
- """
- client = self.get_conn()
- parent = f"projects/{project_id}/locations/{location}"
- result = client.list_datasets(
- request={"parent": parent},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def delete_dataset(
- self,
- dataset_id: str,
- location: str,
- project_id: str,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Operation:
- """
- Delete a dataset and all of its contents.
-
- :param dataset_id: ID of dataset to be deleted.
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance
- """
- client = self.get_conn()
- name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
- result = client.delete_dataset(
- request={"name": name},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
- @GoogleBaseHook.fallback_to_default_project_id
- def get_dataset(
- self,
- dataset_id: str,
- location: str,
- project_id: str,
- retry: Retry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Dataset:
- """
- Retrieve the dataset for the given dataset_id.
-
- :param dataset_id: ID of dataset to be retrieved.
- :param location: The location of the project.
- :param project_id: ID of the Google Cloud project where dataset is located if None then
- default project_id is used.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
-
- :return: `google.cloud.automl_v1beta1.types.dataset.Dataset` instance.
- """
- client = self.get_conn()
- name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}"
- return client.get_dataset(
- request={"name": name},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/automl.py b/providers/google/src/airflow/providers/google/cloud/operators/automl.py
deleted file mode 100644
index 905b0acf2c074..0000000000000
--- a/providers/google/src/airflow/providers/google/cloud/operators/automl.py
+++ /dev/null
@@ -1,1364 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module contains Google AutoML operators."""
-
-from __future__ import annotations
-
-import ast
-from collections.abc import Sequence
-from functools import cached_property
-from typing import TYPE_CHECKING, cast
-
-from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.automl_v1beta1 import (
- ColumnSpec,
- Dataset,
- Model,
- PredictResponse,
- TableSpec,
-)
-
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
-from airflow.providers.google.cloud.links.translate import (
- TranslationDatasetListLink,
- TranslationLegacyDatasetLink,
- TranslationLegacyModelLink,
- TranslationLegacyModelPredictLink,
- TranslationLegacyModelTrainLink,
-)
-from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
-from airflow.providers.google.common.deprecated import deprecated
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
-
-if TYPE_CHECKING:
- from google.api_core.retry import Retry
-
- from airflow.providers.common.compat.sdk import Context
-
-MetaData = Sequence[tuple[str, str]]
-
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator, "
- "airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator, "
- "airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator, "
- "airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator, "
- "airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLTrainModelOperator(GoogleCloudBaseOperator):
- """
- Creates a Google Cloud AutoML model.
-
- .. warning::
- AutoMLTrainModelOperator for tables, video intelligence, vision and natural language has been deprecated
- and is no longer available. Please use
- :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`,
- :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`,
- :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`,
- :class:`airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator`,
- :class:`airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator`
- instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLTrainModelOperator`
-
- :param model: Model definition.
- :param project_id: ID of the Google Cloud project where model will be created if None then
- default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
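- # Illustrative task definition (the model body values are placeholders):
- #   train_model = AutoMLTrainModelOperator(
- #       task_id="train_model",
- #       model={"display_name": "my_model", "dataset_id": "TRL123", "translation_model_metadata": {}},
- #       location="us-central1",
- #   )
-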
- template_fields: Sequence[str] = (
- "model",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (
- TranslationLegacyModelTrainLink(),
- TranslationLegacyModelLink(),
- )
-
- def __init__(
- self,
- *,
- model: dict,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.model = model
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Creating model %s...", self.model["display_name"])
- operation = hook.create_model(
- model=self.model,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationLegacyModelTrainLink.persist(
- context=context,
- dataset_id=self.model["dataset_id"],
- project_id=project_id,
- location=self.location,
- )
- operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation)
- result = Model.to_dict(operation_result)
- model_id = hook.extract_object_id(result)
- self.log.info("Model is created, model_id: %s", model_id)
-
- context["task_instance"].xcom_push(key="model_id", value=model_id)
- if project_id:
- TranslationLegacyModelLink.persist(
- context=context,
- dataset_id=self.model["dataset_id"] or "-",
- model_id=model_id,
- project_id=project_id,
- location=self.location,
- )
- return result
-
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.translate.TranslateTextOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLPredictOperator(GoogleCloudBaseOperator):
- """
- Runs a prediction operation on Google Cloud AutoML.
-
- .. warning::
- AutoMLPredictOperator for text, image, and video prediction has been deprecated.
- Please use endpoint_id param instead of model_id param.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLPredictOperator`
-
- :param model_id: Name of the model requested to serve the prediction.
- :param endpoint_id: Name of the endpoint used for the prediction.
- :param payload: Required. Payload to perform the prediction on; it must match the problem type
- that the model was trained to solve.
- :param project_id: ID of the Google Cloud project where model is located if None then
- default project_id is used.
- :param location: The location of the project.
- :param operation_params: Additional domain-specific parameters for the predictions.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "model_id",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyModelPredictLink(),)
-
- def __init__(
- self,
- *,
- model_id: str | None = None,
- endpoint_id: str | None = None,
- location: str,
- payload: dict,
- operation_params: dict[str, str] | None = None,
- instances: list[str] | None = None,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.model_id = model_id
- self.endpoint_id = endpoint_id
- self.operation_params = operation_params
- self.instances = instances
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.payload = payload
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- @cached_property
- def hook(self) -> CloudAutoMLHook | PredictionServiceHook:
- if self.model_id:
- return CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- # endpoint_id defined
- return PredictionServiceHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
-
- @cached_property
- def model(self) -> Model | None:
- if self.model_id:
- hook = cast("CloudAutoMLHook", self.hook)
- return hook.get_model(
- model_id=self.model_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- return None
-
- def execute(self, context: Context):
- if self.model_id is None and self.endpoint_id is None:
- raise AirflowException("You must specify model_id or endpoint_id!")
-
- hook = self.hook
- if self.model_id:
- result = hook.predict(
- model_id=self.model_id,
- payload=self.payload,
- location=self.location,
- project_id=self.project_id,
- params=self.operation_params,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- else: # self.endpoint_id is defined
- result = hook.predict(
- endpoint_id=self.endpoint_id,
- instances=self.instances,
- payload=self.payload,
- location=self.location,
- project_id=self.project_id,
- parameters=self.operation_params,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
-
- project_id = self.project_id or hook.project_id
- dataset_id: str | None = self.model.dataset_id if self.model else None
- if project_id and self.model_id and dataset_id:
- TranslationLegacyModelPredictLink.persist(
- context=context,
- model_id=self.model_id,
- dataset_id=dataset_id,
- project_id=project_id,
- location=self.location,
- )
- return PredictResponse.to_dict(result)
-
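A minimal replacement sketch for the online-prediction path removed above, using the ``TranslateTextOperator`` named in the deprecation notice; the task id, text, language codes, and project values are illustrative only:

.. code-block:: python

    from airflow.providers.google.cloud.operators.translate import TranslateTextOperator

    # Illustrative values only; substitute your own text, languages, location and project.
    translate_text = TranslateTextOperator(
        task_id="translate_text",
        contents=["The quick brown fox jumps over the lazy dog."],
        source_language_code="en",
        target_language_code="es",
        location="us-central1",
        project_id="my-project",
    )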
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator, "
- "airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator):
- """
- Creates a Google Cloud AutoML dataset.
-
-    .. warning::
-        AutoMLCreateDatasetOperator for tables, video intelligence, vision and natural language has been
-        deprecated and is no longer available. Please use
-        :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator` or
-        :class:`airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLCreateDatasetOperator`
-
- :param dataset: The dataset to create. If a dict is provided, it must be of the
- same form as the protobuf message Dataset.
-    :param project_id: ID of the Google Cloud project where the dataset is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "dataset",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyDatasetLink(),)
-
- def __init__(
- self,
- *,
- dataset: dict,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.dataset = dataset
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Creating dataset %s...", self.dataset)
- result = hook.create_dataset(
- dataset=self.dataset,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- result = Dataset.to_dict(result)
- dataset_id = hook.extract_object_id(result)
-        self.log.info("Creation completed. Dataset id: %s", dataset_id)
-
- context["task_instance"].xcom_push(key="dataset_id", value=dataset_id)
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationLegacyDatasetLink.persist(
- context=context,
- dataset_id=dataset_id,
- project_id=project_id,
- location=self.location,
- )
- return result
-
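For dataset creation, a sketch of the translation-specific replacement; it assumes the ``TranslateCreateDatasetOperator`` accepts a ``dataset`` dict shaped like the Cloud Translation ``Dataset`` message, and all field values shown are illustrative:

.. code-block:: python

    from airflow.providers.google.cloud.operators.translate import TranslateCreateDatasetOperator

    # Illustrative dataset body; field names follow the Cloud Translation Dataset message.
    create_dataset = TranslateCreateDatasetOperator(
        task_id="create_translate_dataset",
        dataset={
            "display_name": "my_dataset",
            "source_language_code": "en",
            "target_language_code": "es",
        },
        location="us-central1",
        project_id="my-project",
    )

For non-translation workloads, :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator` is the other replacement path named in the deprecation notice.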
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator, "
- "airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLImportDataOperator(GoogleCloudBaseOperator):
- """
- Imports data to a Google Cloud AutoML dataset.
-
- .. warning::
-        AutoMLImportDataOperator for tables, video intelligence, vision and natural language has been
-        deprecated and is no longer available. Please use
-        :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator` or
-        :class:`airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLImportDataOperator`
-
- :param dataset_id: ID of dataset to be updated.
- :param input_config: The desired input location and its domain specific semantics, if any.
- If a dict is provided, it must be of the same form as the protobuf message InputConfig.
-    :param project_id: ID of the Google Cloud project where the dataset is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "dataset_id",
- "input_config",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyDatasetLink(),)
-
- def __init__(
- self,
- *,
- dataset_id: str,
- location: str,
- input_config: dict,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.dataset_id = dataset_id
- self.input_config = input_config
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- hook.get_dataset(
- dataset_id=self.dataset_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.log.info("Importing data to dataset...")
- operation = hook.import_data(
- dataset_id=self.dataset_id,
- input_config=self.input_config,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=operation)
-        self.log.info("Import completed.")
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationLegacyDatasetLink.persist(
- context=context,
- dataset_id=self.dataset_id,
- project_id=project_id,
- location=self.location,
- )
-
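A migration sketch for data import, assuming the ``TranslateImportDataOperator`` takes an ``input_config`` in the Cloud Translation import-data request shape; the GCS URI and other values are illustrative:

.. code-block:: python

    from airflow.providers.google.cloud.operators.translate import TranslateImportDataOperator

    # Illustrative input_config; the shape follows the Cloud Translation import-data request.
    import_data = TranslateImportDataOperator(
        task_id="import_translate_data",
        dataset_id="my_dataset_id",
        input_config={
            "input_files": [{"usage": "UNASSIGNED", "gcs_source": {"input_uri": "gs://my-bucket/data.tsv"}}],
        },
        location="us-central1",
        project_id="my-project",
    )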
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- category=AirflowProviderDeprecationWarning,
- reason="Shutdown of legacy version of AutoML Tables on March 31, 2024.",
-)
-class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator):
- """
- Lists column specs in a table.
-
- .. warning::
- Operator AutoMLTablesListColumnSpecsOperator has been deprecated due to shutdown of
- a legacy version of AutoML Tables on March 31, 2024. For additional information
- see: https://cloud.google.com/automl-tables/docs/deprecations.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLTablesListColumnSpecsOperator`
-
- :param dataset_id: Name of the dataset.
- :param table_spec_id: table_spec_id for path builder.
- :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same
- form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask`
- :param filter_: Filter expression, see go/filtering.
- :param page_size: The maximum number of resources contained in the
- underlying API response. If page streaming is performed per
- resource, this parameter does not affect the return value. If page
- streaming is performed per page, this determines the maximum number
- of resources in a page.
-    :param project_id: ID of the Google Cloud project where the dataset is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "dataset_id",
- "table_spec_id",
- "field_mask",
- "filter_",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyDatasetLink(),)
-
- def __init__(
- self,
- *,
- dataset_id: str,
- table_spec_id: str,
- location: str,
- field_mask: dict | None = None,
- filter_: str | None = None,
- page_size: int | None = None,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.dataset_id = dataset_id
- self.table_spec_id = table_spec_id
- self.field_mask = field_mask
- self.filter_ = filter_
- self.page_size = page_size
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Requesting column specs.")
- page_iterator = hook.list_column_specs(
- dataset_id=self.dataset_id,
- table_spec_id=self.table_spec_id,
- field_mask=self.field_mask,
- filter_=self.filter_,
- page_size=self.page_size,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- result = [ColumnSpec.to_dict(spec) for spec in page_iterator]
-        self.log.info("Column specs obtained.")
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationLegacyDatasetLink.persist(
- context=context,
- dataset_id=self.dataset_id,
- project_id=project_id,
- location=self.location,
- )
- return result
-
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator",
- category=AirflowProviderDeprecationWarning,
- reason="Shutdown of legacy version of AutoML Tables on March 31, 2024.",
-)
-class AutoMLTablesUpdateDatasetOperator(GoogleCloudBaseOperator):
- """
- Updates a dataset.
-
- .. warning::
- Operator AutoMLTablesUpdateDatasetOperator has been deprecated due to shutdown of
- a legacy version of AutoML Tables on March 31, 2024. For additional information
- see: https://cloud.google.com/automl-tables/docs/deprecations.
- Please use :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator`
- instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLTablesUpdateDatasetOperator`
-
- :param dataset: The dataset which replaces the resource on the server.
- If a dict is provided, it must be of the same form as the protobuf message Dataset.
- :param update_mask: The update mask applies to the resource. If a dict is provided, it must
- be of the same form as the protobuf message FieldMask.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "dataset",
- "update_mask",
- "location",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyDatasetLink(),)
-
- def __init__(
- self,
- *,
- dataset: dict,
- location: str,
- update_mask: dict | None = None,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.dataset = dataset
- self.update_mask = update_mask
- self.location = location
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Updating AutoML dataset %s.", self.dataset["name"])
- result = hook.update_dataset(
- dataset=self.dataset,
- update_mask=self.update_mask,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.log.info("Dataset updated.")
- project_id = hook.project_id
- if project_id:
- TranslationLegacyDatasetLink.persist(
- context=context,
- dataset_id=hook.extract_object_id(self.dataset),
- project_id=project_id,
- location=self.location,
- )
- return Dataset.to_dict(result)
-
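For the Tables update path, the named replacement is the Vertex AI dataset operator. A sketch under the assumption that it takes a ``region`` plus a dataset body and field mask in the Vertex AI ``Dataset`` / ``FieldMask`` shapes; all values are illustrative:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.dataset import UpdateDatasetOperator

    # Illustrative values; dataset and update_mask follow the Vertex AI Dataset / FieldMask messages.
    update_dataset = UpdateDatasetOperator(
        task_id="update_dataset",
        dataset_id="my_dataset_id",
        dataset={"display_name": "my_dataset_new_name"},
        update_mask={"paths": ["display_name"]},
        region="us-central1",
        project_id="my-project",
    )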
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLGetModelOperator(GoogleCloudBaseOperator):
- """
-    Gets a Google Cloud AutoML model.
-
- .. warning::
-        AutoMLGetModelOperator for tables, video intelligence, vision and natural language has been
-        deprecated and is no longer available. Please use
-        :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLGetModelOperator`
-
- :param model_id: Name of the model requested to serve the prediction.
-    :param project_id: ID of the Google Cloud project where the model is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "model_id",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyModelLink(),)
-
- def __init__(
- self,
- *,
- model_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.model_id = model_id
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- result = hook.get_model(
- model_id=self.model_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- model = Model.to_dict(result)
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationLegacyModelLink.persist(
- context=context,
- dataset_id=model["dataset_id"],
- model_id=self.model_id,
- project_id=project_id,
- location=self.location,
- )
- return model
-
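A minimal sketch of the Vertex AI replacement named above; the model id, region, and project values are illustrative:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.model_service import GetModelOperator

    # Illustrative values only.
    get_model = GetModelOperator(
        task_id="get_model",
        model_id="my_model_id",
        region="us-central1",
        project_id="my-project",
    )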
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator, "
- "airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLDeleteModelOperator(GoogleCloudBaseOperator):
- """
-    Deletes a Google Cloud AutoML model.
-
- .. warning::
-        AutoMLDeleteModelOperator for tables, video intelligence, vision and natural language has been
-        deprecated and is no longer available. Please use
-        :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator` or
-        :class:`airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLDeleteModelOperator`
-
- :param model_id: Name of the model requested to serve the prediction.
-    :param project_id: ID of the Google Cloud project where the model is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "model_id",
- "location",
- "project_id",
- "impersonation_chain",
- )
-
- def __init__(
- self,
- *,
- model_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.model_id = model_id
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- hook.get_model(
- model_id=self.model_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- operation = hook.delete_model(
- model_id=self.model_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=operation)
-        self.log.info("Deletion completed.")
-
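A migration sketch for model deletion via the Vertex AI model service; values are illustrative. For legacy translation models, ``TranslateDeleteModelOperator`` from the translate module is the other replacement named in the deprecation notice:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.model_service import DeleteModelOperator

    # Illustrative values only.
    delete_model = DeleteModelOperator(
        task_id="delete_model",
        model_id="my_model_id",
        region="us-central1",
        project_id="my-project",
    )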
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLDeployModelOperator(GoogleCloudBaseOperator):
- """
- Deploys a model; if a model is already deployed, deploying it with the same parameters has no effect.
-
-    Deploying with different parameters (e.g. changing node_number) will
-    reset the deployment state without pausing the model_id's availability.
-
- Only applicable for Text Classification, Image Object Detection and Tables; all other
- domains manage deployment automatically.
-
- .. warning::
- Operator AutoMLDeployModelOperator has been deprecated due to shutdown of a legacy version
- of AutoML Natural Language, Vision, Video Intelligence on March 31, 2024.
-        For additional information see: https://cloud.google.com/vision/automl/docs/deprecations.
- Please use :class:`airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator`
- instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLDeployModelOperator`
-
- :param model_id: Name of the model to be deployed.
- :param image_detection_metadata: Model deployment metadata specific to Image Object Detection.
- If a dict is provided, it must be of the same form as the protobuf message
- ImageObjectDetectionModelDeploymentMetadata
-    :param project_id: ID of the Google Cloud project where the model is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "model_id",
- "location",
- "project_id",
- "impersonation_chain",
- )
-
- def __init__(
- self,
- *,
- model_id: str,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- image_detection_metadata: dict | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.model_id = model_id
- self.image_detection_metadata = image_detection_metadata
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Deploying model_id %s", self.model_id)
- operation = hook.deploy_model(
- model_id=self.model_id,
- location=self.location,
- project_id=self.project_id,
- image_detection_metadata=self.image_detection_metadata,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- hook.wait_for_operation(timeout=self.timeout, operation=operation)
- self.log.info("Model was deployed successfully.")
-
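The removed operator deployed an AutoML model in place; the Vertex AI replacement instead deploys a model to an endpoint. A sketch assuming ``DeployModelOperator`` takes an endpoint id and a ``deployed_model`` body in the Vertex AI ``DeployedModel`` shape; all names and values are illustrative:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.endpoint_service import DeployModelOperator

    # Illustrative values; deployed_model follows the Vertex AI DeployedModel message.
    deploy_model = DeployModelOperator(
        task_id="deploy_model",
        endpoint_id="my-endpoint-id",
        deployed_model={
            "model": "projects/my-project/locations/us-central1/models/my-model-id",
            "display_name": "my-deployed-model",
            "dedicated_resources": {
                "machine_spec": {"machine_type": "n1-standard-2"},
                "min_replica_count": 1,
            },
        },
        traffic_split={"0": 100},
        region="us-central1",
        project_id="my-project",
    )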
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- category=AirflowProviderDeprecationWarning,
- reason="Shutdown of legacy version of AutoML Tables on March 31, 2024.",
-)
-class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator):
- """
- Lists table specs in a dataset.
-
- .. warning::
- Operator AutoMLTablesListTableSpecsOperator has been deprecated due to shutdown of
- a legacy version of AutoML Tables on March 31, 2024. For additional information
- see: https://cloud.google.com/automl-tables/docs/deprecations.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLTablesListTableSpecsOperator`
-
- :param dataset_id: Name of the dataset.
- :param filter_: Filter expression, see go/filtering.
- :param page_size: The maximum number of resources contained in the
- underlying API response. If page streaming is performed per
- resource, this parameter does not affect the return value. If page
-        streaming is performed per page, this determines the maximum number
- of resources in a page.
-    :param project_id: ID of the Google Cloud project.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "dataset_id",
- "filter_",
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationLegacyDatasetLink(),)
-
- def __init__(
- self,
- *,
- dataset_id: str,
- location: str,
- page_size: int | None = None,
- filter_: str | None = None,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.dataset_id = dataset_id
- self.filter_ = filter_
- self.page_size = page_size
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Requesting table specs for %s.", self.dataset_id)
- page_iterator = hook.list_table_specs(
- dataset_id=self.dataset_id,
- filter_=self.filter_,
- page_size=self.page_size,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- result = [TableSpec.to_dict(spec) for spec in page_iterator]
- self.log.info(result)
- self.log.info("Table specs obtained.")
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationLegacyDatasetLink.persist(
- context=context,
- dataset_id=self.dataset_id,
- project_id=project_id,
- location=self.location,
- )
- return result
-
-
-@deprecated(
- planned_removal_date="September 30, 2025",
- use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator, "
- "airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLListDatasetOperator(GoogleCloudBaseOperator):
- """
-    Lists AutoML datasets in a project.
-
- .. warning::
-        AutoMLListDatasetOperator for tables, video intelligence, vision and natural language has been
-        deprecated and is no longer available. Please use
-        :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator` or
-        :class:`airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLListDatasetOperator`
-
-    :param project_id: ID of the Google Cloud project where the datasets are located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "location",
- "project_id",
- "impersonation_chain",
- )
- operator_extra_links = (TranslationDatasetListLink(),)
-
- def __init__(
- self,
- *,
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- self.log.info("Requesting datasets")
- page_iterator = hook.list_datasets(
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- result = []
- for dataset in page_iterator:
- result.append(Dataset.to_dict(dataset))
- self.log.info("Datasets obtained.")
-
- context["task_instance"].xcom_push(
- key="dataset_id_list",
- value=[hook.extract_object_id(d) for d in result],
- )
- project_id = self.project_id or hook.project_id
- if project_id:
- TranslationDatasetListLink.persist(context=context, project_id=project_id)
- return result
-
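A minimal listing sketch for the Vertex AI replacement; region and project are illustrative. The translation-specific alternative is ``TranslateDatasetsListOperator`` from the translate module:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.dataset import ListDatasetsOperator

    # Illustrative values only.
    list_datasets = ListDatasetsOperator(
        task_id="list_datasets",
        region="us-central1",
        project_id="my-project",
    )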
-
-@deprecated(
- planned_removal_date="September 30, 2025",
-    use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator, "
-    "airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator",
- category=AirflowProviderDeprecationWarning,
-)
-class AutoMLDeleteDatasetOperator(GoogleCloudBaseOperator):
- """
- Deletes a dataset and all of its contents.
-
-    .. warning::
-        AutoMLDeleteDatasetOperator for tables, video intelligence, vision and natural language has been
-        deprecated and is no longer available. Please use
-        :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator` or
-        :class:`airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AutoMLDeleteDatasetOperator`
-
-    :param dataset_id: The ID of the dataset to delete. Accepts a single dataset ID, a list of
-        dataset IDs, or a comma-separated string of dataset IDs.
-    :param project_id: ID of the Google Cloud project where the dataset is located.
-        If None, the default project_id is used.
- :param location: The location of the project.
- :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
- retried.
- :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
- `retry` is specified, the timeout applies to each individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "dataset_id",
- "location",
- "project_id",
- "impersonation_chain",
- )
-
- def __init__(
- self,
- *,
- dataset_id: str | list[str],
- location: str,
- project_id: str = PROVIDE_PROJECT_ID,
- metadata: MetaData = (),
- timeout: float | None = None,
- retry: Retry | _MethodDefault = DEFAULT,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
-
- self.dataset_id = dataset_id
- self.location = location
- self.project_id = project_id
- self.metadata = metadata
- self.timeout = timeout
- self.retry = retry
- self.gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- @staticmethod
- def _parse_dataset_id(dataset_id: str | list[str]) -> list[str]:
- if not isinstance(dataset_id, str):
- return dataset_id
- try:
- return ast.literal_eval(dataset_id)
- except (SyntaxError, ValueError):
- return dataset_id.split(",")
-
- def execute(self, context: Context):
- hook = CloudAutoMLHook(
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
- hook.get_dataset(
- dataset_id=self.dataset_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- dataset_id_list = self._parse_dataset_id(self.dataset_id)
- for dataset_id in dataset_id_list:
- self.log.info("Deleting dataset %s", dataset_id)
- hook.delete_dataset(
- dataset_id=dataset_id,
- location=self.location,
- project_id=self.project_id,
- retry=self.retry,
- timeout=self.timeout,
- metadata=self.metadata,
- )
- self.log.info("Dataset deleted.")
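A migration sketch for dataset deletion, assuming the Vertex AI replacement named above; unlike the removed operator it deletes a single dataset per task, so fan out (e.g. with dynamic task mapping) when several IDs must be removed. All values are illustrative:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.dataset import DeleteDatasetOperator

    # Illustrative values only; one dataset per task.
    delete_dataset = DeleteDatasetOperator(
        task_id="delete_dataset",
        dataset_id="my_dataset_id",
        region="us-central1",
        project_id="my-project",
    )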
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py b/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py
deleted file mode 100644
index 2fa2c75fed899..0000000000000
--- a/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py
+++ /dev/null
@@ -1,111 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module contains Google Cloud MLEngine operators."""
-
-from __future__ import annotations
-
-import logging
-from collections.abc import Sequence
-from typing import TYPE_CHECKING
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook
-from airflow.providers.google.cloud.links.mlengine import MLEngineModelLink
-from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
-from airflow.providers.google.common.deprecated import deprecated
-from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
-
-if TYPE_CHECKING:
- from airflow.providers.common.compat.sdk import Context
-
-
-log = logging.getLogger(__name__)
-
-
-@deprecated(
- planned_removal_date="November 01, 2025",
- use_instead="appropriate VertexAI operator",
-    reason="All the functionality of legacy MLEngine, plus new features, is available on Vertex AI.",
- category=AirflowProviderDeprecationWarning,
-)
-class MLEngineCreateModelOperator(GoogleCloudBaseOperator):
- """
- Creates a new model.
-
- .. warning::
- This operator is deprecated. Please use appropriate VertexAI operator from
- :class:`airflow.providers.google.cloud.operators.vertex_ai` instead.
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:MLEngineCreateModelOperator`
-
- The model should be provided by the `model` parameter.
-
- :param model: A dictionary containing the information about the model.
- :param project_id: The Google Cloud project name to which MLEngine model belongs.
- If set to None or missing, the default project_id from the Google Cloud connection is used.
- (templated)
- :param gcp_conn_id: The connection ID to use when fetching connection info.
- :param impersonation_chain: Optional service account to impersonate using short-term
- credentials, or chained list of accounts required to get the access_token
- of the last account in the list, which will be impersonated in the request.
- If set as a string, the account must grant the originating account
- the Service Account Token Creator IAM role.
- If set as a sequence, the identities from the list must grant
- Service Account Token Creator IAM role to the directly preceding identity, with first
- account from the list granting this role to the originating account (templated).
- """
-
- template_fields: Sequence[str] = (
- "project_id",
- "model",
- "impersonation_chain",
- )
- operator_extra_links = (MLEngineModelLink(),)
-
- def __init__(
- self,
- *,
- model: dict,
- project_id: str = PROVIDE_PROJECT_ID,
- gcp_conn_id: str = "google_cloud_default",
- impersonation_chain: str | Sequence[str] | None = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.project_id = project_id
- self.model = model
- self._gcp_conn_id = gcp_conn_id
- self.impersonation_chain = impersonation_chain
-
- def execute(self, context: Context):
- hook = MLEngineHook(
- gcp_conn_id=self._gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- )
-
- project_id = self.project_id or hook.project_id
- if project_id:
- MLEngineModelLink.persist(
- context=context,
- project_id=project_id,
- model_id=self.model["name"],
- )
-
- return hook.create_model(project_id=self.project_id, model=self.model)
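The deprecation notice above points only at an "appropriate VertexAI operator". For model creation specifically, one plausible path is uploading a model through the Vertex AI model service; this sketch assumes an ``UploadModelOperator`` that accepts a ``model`` body in the Vertex AI ``Model`` message shape, and every value shown is illustrative rather than taken from this change:

.. code-block:: python

    from airflow.providers.google.cloud.operators.vertex_ai.model_service import UploadModelOperator

    # Illustrative model body; fields follow the Vertex AI Model message.
    upload_model = UploadModelOperator(
        task_id="upload_model",
        model={
            "display_name": "my-model",
            "artifact_uri": "gs://my-bucket/model/",
            "container_spec": {
                "image_uri": "us-docker.pkg.dev/vertex-ai/prediction/sklearn-cpu.1-0:latest",
            },
        },
        region="us-central1",
        project_id="my-project",
    )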
diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py
index 6f585a2278fc8..12d473d44a2e0 100644
--- a/providers/google/src/airflow/providers/google/get_provider_info.py
+++ b/providers/google/src/airflow/providers/google/get_provider_info.py
@@ -43,13 +43,6 @@ def get_provider_info():
"how-to-guide": ["/docs/apache-airflow-providers-google/operators/ads.rst"],
"tags": ["gmp"],
},
- {
- "integration-name": "Google AutoML",
- "external-doc-url": "https://cloud.google.com/automl/",
- "how-to-guide": ["/docs/apache-airflow-providers-google/operators/cloud/automl.rst"],
- "logo": "/docs/integration-logos/Cloud-AutoML.png",
- "tags": ["gcp"],
- },
{
"integration-name": "Google BigQuery Data Transfer Service",
"external-doc-url": "https://cloud.google.com/bigquery/transfer/",
@@ -495,10 +488,6 @@ def get_provider_info():
"integration-name": "Google Cloud AlloyDB",
"python-modules": ["airflow.providers.google.cloud.operators.alloy_db"],
},
- {
- "integration-name": "Google AutoML",
- "python-modules": ["airflow.providers.google.cloud.operators.automl"],
- },
{
"integration-name": "Google BigQuery",
"python-modules": ["airflow.providers.google.cloud.operators.bigquery"],
@@ -583,10 +572,6 @@ def get_provider_info():
"integration-name": "Google Kubernetes Engine",
"python-modules": ["airflow.providers.google.cloud.operators.kubernetes_engine"],
},
- {
- "integration-name": "Google Machine Learning Engine",
- "python-modules": ["airflow.providers.google.cloud.operators.mlengine"],
- },
{
"integration-name": "Google Cloud Natural Language",
"python-modules": ["airflow.providers.google.cloud.operators.natural_language"],
@@ -826,10 +811,6 @@ def get_provider_info():
],
"hooks": [
{"integration-name": "Google Ads", "python-modules": ["airflow.providers.google.ads.hooks.ads"]},
- {
- "integration-name": "Google AutoML",
- "python-modules": ["airflow.providers.google.cloud.hooks.automl"],
- },
{
"integration-name": "Google BigQuery",
"python-modules": ["airflow.providers.google.cloud.hooks.bigquery"],
diff --git a/providers/google/tests/deprecations_ignore.yml b/providers/google/tests/deprecations_ignore.yml
index 66a1508a6ec52..d07c6bdfc21a0 100644
--- a/providers/google/tests/deprecations_ignore.yml
+++ b/providers/google/tests/deprecations_ignore.yml
@@ -80,26 +80,6 @@
- providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_wc_with_last_modified_time_with_all_true_cond
- providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_wc_with_last_modified_time_with_one_true_cond
- providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_wc_with_no_last_modified_time
-- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationLegacyDatasetLink::test_get_link
-- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationDatasetListLink::test_get_link
-- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationLegacyModelLink::test_get_link
-- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationLegacyModelTrainLink::test_get_link
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_get_conn
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_prediction_client
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_create_model
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_batch_predict
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_predict
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_create_dataset
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_import_dataset
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_list_column_specs
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_get_model
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_delete_model
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_update_dataset
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_deploy_model
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_list_table_specs
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_list_datasets
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_delete_dataset
-- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_get_dataset
- providers/google/tests/unit/google/cloud/hooks/test_datacatalog.py::TestCloudDataCatalog::test_lookup_entry_with_linked_resource
- providers/google/tests/unit/google/cloud/hooks/test_datacatalog.py::TestCloudDataCatalog::test_lookup_entry_with_sql_resource
- providers/google/tests/unit/google/cloud/hooks/test_datacatalog.py::TestCloudDataCatalog::test_lookup_entry_without_resource
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_automl.py b/providers/google/tests/unit/google/cloud/hooks/test_automl.py
deleted file mode 100644
index ec49f23df9e2c..0000000000000
--- a/providers/google/tests/unit/google/cloud/hooks/test_automl.py
+++ /dev/null
@@ -1,257 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest import mock
-
-from google.api_core.gapic_v1.method import DEFAULT
-from google.cloud.automl_v1beta1 import AutoMlClient
-
-from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.common.consts import CLIENT_INFO
-
-from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id
-
-CREDENTIALS = "test-creds"
-TASK_ID = "test-automl-hook"
-GCP_PROJECT_ID = "test-project"
-GCP_LOCATION = "test-location"
-MODEL_NAME = "test_model"
-MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152"
-DATASET_ID = "TBL123456789"
-MODEL = {
- "display_name": MODEL_NAME,
- "dataset_id": DATASET_ID,
- "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
-}
-
-LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
-MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
-DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
-
-INPUT_CONFIG = {"input": "value"}
-OUTPUT_CONFIG = {"output": "value"}
-PAYLOAD = {"test": "payload"}
-DATASET = {"dataset_id": "data"}
-MASK = {"field": "mask"}
-
-
-class TestAutoMLHook:
- def setup_method(self):
- with mock.patch(
- "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.__init__",
- new=mock_base_gcp_hook_no_default_project_id,
- ):
- self.hook = CloudAutoMLHook()
- self.hook.get_credentials = mock.MagicMock(return_value=CREDENTIALS)
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient")
- def test_get_conn(self, mock_automl_client):
- self.hook.get_conn()
- mock_automl_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO)
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient")
- def test_prediction_client(self, mock_prediction_client):
- client = self.hook.prediction_client # noqa: F841
- mock_prediction_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO)
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model")
- def test_create_model(self, mock_create_model):
- self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_create_model.assert_called_once_with(
- request=dict(parent=LOCATION_PATH, model=MODEL), retry=DEFAULT, timeout=None, metadata=()
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict")
- def test_batch_predict(self, mock_batch_predict):
- self.hook.batch_predict(
- model_id=MODEL_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- input_config=INPUT_CONFIG,
- output_config=OUTPUT_CONFIG,
- )
-
- mock_batch_predict.assert_called_once_with(
- request=dict(
- name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None
- ),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict")
- def test_predict(self, mock_predict):
- self.hook.predict(
- model_id=MODEL_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- payload=PAYLOAD,
- )
-
- mock_predict.assert_called_once_with(
- request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset")
- def test_create_dataset(self, mock_create_dataset):
- self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_create_dataset.assert_called_once_with(
- request=dict(parent=LOCATION_PATH, dataset=DATASET),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data")
- def test_import_dataset(self, mock_import_data):
- self.hook.import_data(
- dataset_id=DATASET_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- input_config=INPUT_CONFIG,
- )
-
- mock_import_data.assert_called_once_with(
- request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs")
- def test_list_column_specs(self, mock_list_column_specs):
- table_spec = "table_spec_id"
- filter_ = "filter"
- page_size = 42
-
- self.hook.list_column_specs(
- dataset_id=DATASET_ID,
- table_spec_id=table_spec,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- field_mask=MASK,
- filter_=filter_,
- page_size=page_size,
- )
-
- parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec)
- mock_list_column_specs.assert_called_once_with(
- request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model")
- def test_get_model(self, mock_get_model):
- self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_get_model.assert_called_once_with(
- request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=()
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model")
- def test_delete_model(self, mock_delete_model):
- self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_delete_model.assert_called_once_with(
- request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=()
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset")
- def test_update_dataset(self, mock_update_dataset):
- self.hook.update_dataset(
- dataset=DATASET,
- update_mask=MASK,
- )
-
- mock_update_dataset.assert_called_once_with(
- request=dict(dataset=DATASET, update_mask=MASK), retry=DEFAULT, timeout=None, metadata=()
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model")
- def test_deploy_model(self, mock_deploy_model):
- image_detection_metadata = {}
-
- self.hook.deploy_model(
- model_id=MODEL_ID,
- image_detection_metadata=image_detection_metadata,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- )
-
- mock_deploy_model.assert_called_once_with(
- request=dict(
- name=MODEL_PATH,
- image_object_detection_model_deployment_metadata=image_detection_metadata,
- ),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs")
- def test_list_table_specs(self, mock_list_table_specs):
- filter_ = "filter"
- page_size = 42
-
- self.hook.list_table_specs(
- dataset_id=DATASET_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- filter_=filter_,
- page_size=page_size,
- )
-
- mock_list_table_specs.assert_called_once_with(
- request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size),
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets")
- def test_list_datasets(self, mock_list_datasets):
- self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_list_datasets.assert_called_once_with(
- request=dict(parent=LOCATION_PATH), retry=DEFAULT, timeout=None, metadata=()
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset")
- def test_delete_dataset(self, mock_delete_dataset):
- self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_delete_dataset.assert_called_once_with(
- request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, metadata=()
- )
-
- @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_dataset")
- def test_get_dataset(self, mock_get_dataset):
- self.hook.get_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID)
-
- mock_get_dataset.assert_called_once_with(
- request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, metadata=()
- )
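
The hook tests removed above all assert one and the same contract: each CloudAutoMLHook method packs its arguments into a single request mapping and forwards it to the GAPIC client together with the standard retry/timeout/metadata keywords. A minimal, self-contained sketch of that pattern follows; FakeAutoMLHook and its constants are illustrative stand-ins, not the provider's real classes.

from unittest import mock

from google.api_core.gapic_v1.method import DEFAULT

LOCATION_PATH = "projects/test-project/locations/test-location"  # illustrative
DATASET = {"dataset_id": "data"}  # illustrative


class FakeAutoMLHook:
    """Illustrative hook wrapping a GAPIC-style client."""

    def __init__(self, client):
        self._client = client

    def create_dataset(self, dataset, parent):
        # GAPIC v1 methods accept one `request` mapping plus the
        # retry/timeout/metadata keyword arguments asserted above.
        return self._client.create_dataset(
            request=dict(parent=parent, dataset=dataset),
            retry=DEFAULT,
            timeout=None,
            metadata=(),
        )


def test_create_dataset_forwards_request():
    client = mock.MagicMock()
    FakeAutoMLHook(client).create_dataset(dataset=DATASET, parent=LOCATION_PATH)
    client.create_dataset.assert_called_once_with(
        request=dict(parent=LOCATION_PATH, dataset=DATASET),
        retry=DEFAULT,
        timeout=None,
        metadata=(),
    )
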
diff --git a/providers/google/tests/unit/google/cloud/links/test_translate.py b/providers/google/tests/unit/google/cloud/links/test_translate.py
deleted file mode 100644
index 0adc28e53e218..0000000000000
--- a/providers/google/tests/unit/google/cloud/links/test_translate.py
+++ /dev/null
@@ -1,198 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest import mock
-
-import pytest
-
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-# In environments without Pydantic the google.cloud.aiplatform_v1 package
-# cannot be imported, so skip these tests there
-pytest.importorskip("google.cloud.aiplatform_v1")
-
-from airflow.providers.google.cloud.links.translate import (
- TRANSLATION_BASE_LINK,
- TranslationDatasetListLink,
- TranslationLegacyDatasetLink,
- TranslationLegacyModelLink,
- TranslationLegacyModelTrainLink,
-)
-from airflow.providers.google.cloud.operators.automl import (
- AutoMLCreateDatasetOperator,
- AutoMLListDatasetOperator,
- AutoMLTrainModelOperator,
-)
-
-if AIRFLOW_V_3_0_PLUS:
- from airflow.sdk.execution_time.comms import XComResult
-
-GCP_LOCATION = "test-location"
-GCP_PROJECT_ID = "test-project"
-DATASET = "test-dataset"
-MODEL = "test-model"
-
-
-class TestTranslationLegacyDatasetLink:
- @pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
- expected_url = f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/sentences?project={GCP_PROJECT_ID}"
- link = TranslationLegacyDatasetLink()
- ti = create_task_instance_of_operator(
- AutoMLCreateDatasetOperator,
- dag_id="test_legacy_dataset_link_dag",
- task_id="test_legacy_dataset_link_task",
- dataset=DATASET,
- location=GCP_LOCATION,
- )
- session.add(ti)
- session.commit()
-
- link.persist(
- context={"ti": ti, "task": ti.task},
- dataset_id=DATASET,
- project_id=GCP_PROJECT_ID,
- location=GCP_LOCATION,
- )
-
- if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
- mock_supervisor_comms.send.return_value = XComResult(
- key="key",
- value={
- "location": ti.task.location,
- "dataset_id": DATASET,
- "project_id": GCP_PROJECT_ID,
- },
- )
- actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
- assert actual_url == expected_url
-
-
-class TestTranslationDatasetListLink:
- @pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
- expected_url = f"{TRANSLATION_BASE_LINK}/datasets?project={GCP_PROJECT_ID}"
- link = TranslationDatasetListLink()
- ti = create_task_instance_of_operator(
- AutoMLListDatasetOperator,
- dag_id="test_dataset_list_link_dag",
- task_id="test_dataset_list_link_task",
- location=GCP_LOCATION,
- )
- session.add(ti)
- session.commit()
- link.persist(context={"ti": ti, "task": ti.task}, project_id=GCP_PROJECT_ID)
-
- if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
- mock_supervisor_comms.send.return_value = XComResult(
- key="key",
- value={
- "project_id": GCP_PROJECT_ID,
- },
- )
- actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
- assert actual_url == expected_url
-
-
-class TestTranslationLegacyModelLink:
- @pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
- expected_url = (
- f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
- f"evaluate;modelId={MODEL}?project={GCP_PROJECT_ID}"
- )
- link = TranslationLegacyModelLink()
- ti = create_task_instance_of_operator(
- AutoMLTrainModelOperator,
- dag_id="test_legacy_model_link_dag",
- task_id="test_legacy_model_link_task",
- model=MODEL,
- project_id=GCP_PROJECT_ID,
- location=GCP_LOCATION,
- )
- session.add(ti)
- session.commit()
- task = mock.MagicMock()
- task.extra_links_params = {
- "dataset_id": DATASET,
- "model_id": MODEL,
- "project_id": GCP_PROJECT_ID,
- "location": GCP_LOCATION,
- }
- link.persist(
- context={"ti": ti, "task": task},
- dataset_id=DATASET,
- model_id=MODEL,
- project_id=GCP_PROJECT_ID,
- )
-        if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
- mock_supervisor_comms.send.return_value = XComResult(
- key="key",
- value={
- "location": ti.task.location,
- "dataset_id": DATASET,
- "model_id": MODEL,
- "project_id": GCP_PROJECT_ID,
- },
- )
- actual_url = link.get_link(operator=task, ti_key=ti.key)
- assert actual_url == expected_url
-
-
-class TestTranslationLegacyModelTrainLink:
- @pytest.mark.db_test
- def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
- expected_url = (
- f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
- f"train?project={GCP_PROJECT_ID}"
- )
- link = TranslationLegacyModelTrainLink()
- ti = create_task_instance_of_operator(
- AutoMLTrainModelOperator,
- dag_id="test_legacy_model_train_link_dag",
- task_id="test_legacy_model_train_link_task",
- model={"dataset_id": DATASET},
- project_id=GCP_PROJECT_ID,
- location=GCP_LOCATION,
- )
- session.add(ti)
- session.commit()
- task = mock.MagicMock()
- task.extra_links_params = {
- "dataset_id": DATASET,
- "project_id": GCP_PROJECT_ID,
- "location": GCP_LOCATION,
- }
- link.persist(
- context={"ti": ti, "task": task},
- dataset_id=DATASET,
- project_id=GCP_PROJECT_ID,
- location=GCP_LOCATION,
- )
-
-        if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
- mock_supervisor_comms.send.return_value = XComResult(
- key="key",
- value={
- "location": ti.task.location,
- "dataset_id": ti.task.model["dataset_id"],
- "project_id": GCP_PROJECT_ID,
- },
- )
- actual_url = link.get_link(operator=task, ti_key=ti.key)
- assert actual_url == expected_url
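
The link tests deleted above pin down the exact Cloud Console URL each extra link renders from its persisted XCom fields. The shape of the legacy dataset URL can be reproduced with plain string formatting; the base-URL value below is an assumption for illustration (the deleted tests imported the real constant, TRANSLATION_BASE_LINK, from airflow.providers.google.cloud.links.translate).

# Minimal sketch of the URL shape asserted in TestTranslationLegacyDatasetLink.
TRANSLATION_BASE_LINK = "https://console.cloud.google.com/translation"  # assumed value


def legacy_dataset_url(project_id: str, location: str, dataset_id: str) -> str:
    return (
        f"{TRANSLATION_BASE_LINK}/locations/{location}/datasets/{dataset_id}"
        f"/sentences?project={project_id}"
    )


assert legacy_dataset_url("test-project", "test-location", "test-dataset") == (
    f"{TRANSLATION_BASE_LINK}/locations/test-location"
    "/datasets/test-dataset/sentences?project=test-project"
)
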
diff --git a/providers/google/tests/unit/google/cloud/operators/test_automl.py b/providers/google/tests/unit/google/cloud/operators/test_automl.py
deleted file mode 100644
index 18f4de4af709d..0000000000000
--- a/providers/google/tests/unit/google/cloud/operators/test_automl.py
+++ /dev/null
@@ -1,616 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import copy
-from unittest import mock
-
-import pytest
-
-# In environments without Pydantic the google.cloud.aiplatform_v1 package
-# cannot be imported, so skip these tests there
-pytest.importorskip("google.cloud.aiplatform_v1")
-
-from google.api_core.gapic_v1.method import DEFAULT
-from google.cloud.automl_v1beta1 import Dataset, Model, PredictResponse
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
-from airflow.providers.google.cloud.operators.automl import (
- AutoMLCreateDatasetOperator,
- AutoMLDeleteDatasetOperator,
- AutoMLDeleteModelOperator,
- AutoMLDeployModelOperator,
- AutoMLGetModelOperator,
- AutoMLImportDataOperator,
- AutoMLListDatasetOperator,
- AutoMLPredictOperator,
- AutoMLTablesListColumnSpecsOperator,
- AutoMLTablesListTableSpecsOperator,
- AutoMLTablesUpdateDatasetOperator,
- AutoMLTrainModelOperator,
-)
-
-CREDENTIALS = "test-creds"
-TASK_ID = "test-automl-hook"
-GCP_PROJECT_ID = "test-project"
-GCP_LOCATION = "test-location"
-MODEL_NAME = "test_model"
-MODEL_ID = "TBL9195602771183665152"
-DATASET_ID = "TBL123456789"
-MODEL = {
- "display_name": MODEL_NAME,
- "dataset_id": DATASET_ID,
- "translation_model_metadata": {"train_budget_milli_node_hours": 1000},
-}
-MODEL_DEPRECATED = {
- "display_name": MODEL_NAME,
- "dataset_id": DATASET_ID,
- "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
-}
-
-LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
-MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
-DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
-
-INPUT_CONFIG = {"input": "value"}
-OUTPUT_CONFIG = {"output": "value"}
-PAYLOAD = {"test": "payload"}
-DATASET = {"dataset_id": "data", "translation_dataset_metadata": "data"}
-DATASET_DEPRECATED = {"tables_model_metadata": "data"}
-MASK = {"field": "mask"}
-
-extract_object_id = CloudAutoMLHook.extract_object_id
-
-
-class TestAutoMLTrainModelOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
- mock_hook.return_value.extract_object_id = extract_object_id
- mock_hook.return_value.wait_for_operation.return_value = Model()
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLTrainModelOperator(
- model=MODEL,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- task_id=TASK_ID,
- )
- op.execute(context=mock.MagicMock())
-
- mock_hook.return_value.create_model.assert_called_once_with(
- model=MODEL,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- metadata=(),
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLTrainModelOperator,
- # Templated fields
- model="{{ 'model' }}",
- location="{{ 'location' }}",
- impersonation_chain="{{ 'impersonation_chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLTrainModelOperator = ti.task
- assert task.model == "model"
- assert task.location == "location"
- assert task.impersonation_chain == "impersonation_chain"
-
-
-class TestAutoMLPredictOperator:
- @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist")
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook, mock_link_persist):
- mock_hook.return_value.predict.return_value = PredictResponse()
- mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
- mock_context = {"ti": mock.MagicMock()}
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLPredictOperator(
- model_id=MODEL_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- payload=PAYLOAD,
- task_id=TASK_ID,
- operation_params={"TEST_KEY": "TEST_VALUE"},
- )
- op.execute(context=mock_context)
- mock_hook.return_value.predict.assert_called_once_with(
- location=GCP_LOCATION,
- metadata=(),
- model_id=MODEL_ID,
- params={"TEST_KEY": "TEST_VALUE"},
- payload=PAYLOAD,
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
- mock_link_persist.assert_called_once_with(
- context=mock_context,
- model_id=MODEL_ID,
- project_id=GCP_PROJECT_ID,
- location=GCP_LOCATION,
- dataset_id=DATASET_ID,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLPredictOperator,
- # Templated fields
- model_id="{{ 'model-id' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- payload={},
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLPredictOperator = ti.task
- assert task.model_id == "model-id"
- assert task.project_id == "project-id"
- assert task.location == "location"
- assert task.impersonation_chain == "impersonation-chain"
-
- @pytest.mark.db_test
- def test_hook_type(self):
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLPredictOperator(
- model_id=MODEL_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- payload=PAYLOAD,
- task_id=TASK_ID,
- operation_params={"TEST_KEY": "TEST_VALUE"},
- )
- with pytest.warns(AirflowProviderDeprecationWarning):
- assert isinstance(op.hook, CloudAutoMLHook)
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLPredictOperator(
- endpoint_id="endpoint_id",
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- payload=PAYLOAD,
- task_id=TASK_ID,
- operation_params={"TEST_KEY": "TEST_VALUE"},
- )
- assert isinstance(op.hook, PredictionServiceHook)
-
-
-class TestAutoMLCreateDatasetOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
- mock_hook.return_value.extract_object_id = extract_object_id
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLCreateDatasetOperator(
- dataset=DATASET,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- task_id=TASK_ID,
- )
- op.execute(context=mock.MagicMock())
-
- mock_hook.return_value.create_dataset.assert_called_once_with(
- dataset=DATASET,
- location=GCP_LOCATION,
- metadata=(),
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLCreateDatasetOperator,
- # Templated fields
- dataset="{{ 'dataset' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLCreateDatasetOperator = ti.task
- assert task.dataset == "dataset"
- assert task.project_id == "project-id"
- assert task.location == "location"
- assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLTablesListColumnSpecsOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- table_spec = "table_spec_id"
- filter_ = "filter"
- page_size = 42
-
- with pytest.warns(AirflowProviderDeprecationWarning):
- AutoMLTablesListColumnSpecsOperator(
- dataset_id=DATASET_ID,
- table_spec_id=table_spec,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- field_mask=MASK,
- filter_=filter_,
- page_size=page_size,
- task_id=TASK_ID,
- )
- mock_hook.assert_not_called()
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
- with pytest.warns(AirflowProviderDeprecationWarning):
- create_task_instance_of_operator(
- AutoMLTablesListColumnSpecsOperator,
- # Templated fields
- dataset_id="{{ 'dataset-id' }}",
- table_spec_id="{{ 'table-spec-id' }}",
- field_mask="{{ 'field-mask' }}",
- filter_="{{ 'filter-' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
-
-
-class TestAutoMLTablesUpdateDatasetOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
- dataset = copy.deepcopy(DATASET)
- dataset["name"] = DATASET_ID
-
- with pytest.warns(AirflowProviderDeprecationWarning):
- AutoMLTablesUpdateDatasetOperator(
- dataset=dataset,
- update_mask=MASK,
- location=GCP_LOCATION,
- task_id=TASK_ID,
- )
- mock_hook.assert_not_called()
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
- with pytest.warns(AirflowProviderDeprecationWarning):
- create_task_instance_of_operator(
- AutoMLTablesUpdateDatasetOperator,
- # Templated fields
- dataset="{{ 'dataset' }}",
- update_mask="{{ 'update-mask' }}",
- location="{{ 'location' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
-
-
-class TestAutoMLGetModelOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
- mock_hook.return_value.extract_object_id = extract_object_id
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLGetModelOperator(
- model_id=MODEL_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- task_id=TASK_ID,
- )
-
- op.execute(context=mock.MagicMock())
-
- mock_hook.return_value.get_model.assert_called_once_with(
- location=GCP_LOCATION,
- metadata=(),
- model_id=MODEL_ID,
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLGetModelOperator,
- # Templated fields
- model_id="{{ 'model-id' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLGetModelOperator = ti.task
- assert task.model_id == "model-id"
- assert task.location == "location"
- assert task.project_id == "project-id"
- assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLDeleteModelOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLDeleteModelOperator(
- model_id=MODEL_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- task_id=TASK_ID,
- )
- op.execute(context=None)
- mock_hook.return_value.delete_model.assert_called_once_with(
- location=GCP_LOCATION,
- metadata=(),
- model_id=MODEL_ID,
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLDeleteModelOperator,
- # Templated fields
- model_id="{{ 'model-id' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLDeleteModelOperator = ti.task
- assert task.model_id == "model-id"
- assert task.location == "location"
- assert task.project_id == "project-id"
- assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLDeployModelOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- image_detection_metadata = {}
-
- with pytest.warns(AirflowProviderDeprecationWarning):
- AutoMLDeployModelOperator(
- model_id=MODEL_ID,
- image_detection_metadata=image_detection_metadata,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- task_id=TASK_ID,
- )
-
- mock_hook.assert_not_called()
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
- with pytest.warns(AirflowProviderDeprecationWarning):
- create_task_instance_of_operator(
- AutoMLDeployModelOperator,
- # Templated fields
- model_id="{{ 'model-id' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
-
-
-class TestAutoMLDatasetImportOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLImportDataOperator(
- dataset_id=DATASET_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- input_config=INPUT_CONFIG,
- task_id=TASK_ID,
- )
-
- op.execute(context=mock.MagicMock())
-
- mock_hook.return_value.import_data.assert_called_once_with(
- input_config=INPUT_CONFIG,
- location=GCP_LOCATION,
- metadata=(),
- dataset_id=DATASET_ID,
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLImportDataOperator,
- # Templated fields
- dataset_id="{{ 'dataset-id' }}",
- input_config="{{ 'input-config' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLImportDataOperator = ti.task
- assert task.dataset_id == "dataset-id"
- assert task.input_config == "input-config"
- assert task.location == "location"
- assert task.project_id == "project-id"
- assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLTablesListTableSpecsOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- filter_ = "filter"
- page_size = 42
-
- with pytest.warns(AirflowProviderDeprecationWarning):
- AutoMLTablesListTableSpecsOperator(
- dataset_id=DATASET_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- filter_=filter_,
- page_size=page_size,
- task_id=TASK_ID,
- )
- mock_hook.assert_not_called()
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator):
- with pytest.warns(AirflowProviderDeprecationWarning):
- create_task_instance_of_operator(
- AutoMLTablesListTableSpecsOperator,
- # Templated fields
- dataset_id="{{ 'dataset-id' }}",
- filter_="{{ 'filter-' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
-
-
-class TestAutoMLDatasetListOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
-
- op.execute(context=mock.MagicMock())
-
- mock_hook.return_value.list_datasets.assert_called_once_with(
- location=GCP_LOCATION,
- metadata=(),
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLListDatasetOperator,
- # Templated fields
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLListDatasetOperator = ti.task
- assert task.location == "location"
- assert task.project_id == "project-id"
- assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLDatasetDeleteOperator:
- @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
- def test_execute(self, mock_hook):
- with pytest.warns(AirflowProviderDeprecationWarning):
- op = AutoMLDeleteDatasetOperator(
- dataset_id=DATASET_ID,
- location=GCP_LOCATION,
- project_id=GCP_PROJECT_ID,
- task_id=TASK_ID,
- )
- op.execute(context=None)
- mock_hook.return_value.delete_dataset.assert_called_once_with(
- location=GCP_LOCATION,
- dataset_id=DATASET_ID,
- metadata=(),
- project_id=GCP_PROJECT_ID,
- retry=DEFAULT,
- timeout=None,
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- AutoMLDeleteDatasetOperator,
- # Templated fields
- dataset_id="{{ 'dataset-id' }}",
- location="{{ 'location' }}",
- project_id="{{ 'project-id' }}",
- impersonation_chain="{{ 'impersonation-chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: AutoMLDeleteDatasetOperator = ti.task
- assert task.dataset_id == "dataset-id"
- assert task.location == "location"
- assert task.project_id == "project-id"
- assert task.impersonation_chain == "impersonation-chain"
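
Every constructor call in the operator tests deleted above is wrapped in pytest.warns(AirflowProviderDeprecationWarning), so the suite fails if a deprecated operator stops warning before it is removed. A minimal sketch of that mechanism, using a hypothetical warning class and operator rather than the provider's real ones:

import warnings

import pytest


class ProviderDeprecationWarning(DeprecationWarning):
    """Illustrative stand-in for AirflowProviderDeprecationWarning."""


class LegacyOperator:
    """Hypothetical operator that warns on instantiation, as the removed
    AutoML operators did."""

    def __init__(self, task_id):
        warnings.warn(
            "LegacyOperator is deprecated; use its documented replacement.",
            ProviderDeprecationWarning,
            stacklevel=2,
        )
        self.task_id = task_id


def test_constructor_emits_deprecation_warning():
    # pytest.warns fails the test when no matching warning is raised.
    with pytest.warns(ProviderDeprecationWarning):
        LegacyOperator(task_id="legacy-task")
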
diff --git a/providers/google/tests/unit/google/cloud/operators/test_mlengine.py b/providers/google/tests/unit/google/cloud/operators/test_mlengine.py
deleted file mode 100644
index 5ca6cf8cea73d..0000000000000
--- a/providers/google/tests/unit/google/cloud/operators/test_mlengine.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.operators.mlengine import MLEngineCreateModelOperator
-
-TEST_PROJECT_ID = "test-project-id"
-TEST_MODEL_NAME = "test-model-name"
-TEST_GCP_CONN_ID = "test-gcp-conn-id"
-TEST_IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
-TEST_MODEL = {
- "name": TEST_MODEL_NAME,
-}
-MLENGINE_AI_PATH = "airflow.providers.google.cloud.operators.mlengine.{}"
-
-
-class TestMLEngineCreateModelOperator:
- @patch(MLENGINE_AI_PATH.format("MLEngineHook"))
- def test_success_create_model(self, mock_hook):
- with pytest.warns(AirflowProviderDeprecationWarning):
- task = MLEngineCreateModelOperator(
- task_id="task-id",
- project_id=TEST_PROJECT_ID,
- model=TEST_MODEL,
- gcp_conn_id=TEST_GCP_CONN_ID,
- impersonation_chain=TEST_IMPERSONATION_CHAIN,
- )
-
- task.execute(context=MagicMock())
-
- mock_hook.assert_called_once_with(
- gcp_conn_id=TEST_GCP_CONN_ID,
- impersonation_chain=TEST_IMPERSONATION_CHAIN,
- )
- mock_hook.return_value.create_model.assert_called_once_with(
- project_id=TEST_PROJECT_ID, model=TEST_MODEL
- )
-
- @pytest.mark.db_test
- def test_templating(self, create_task_instance_of_operator, session):
- with pytest.warns(AirflowProviderDeprecationWarning):
- ti = create_task_instance_of_operator(
- MLEngineCreateModelOperator,
- # Templated fields
- project_id="{{ 'project_id' }}",
- model="{{ 'model' }}",
- impersonation_chain="{{ 'impersonation_chain' }}",
- # Other parameters
- dag_id="test_template_body_templating_dag",
- task_id="test_template_body_templating_task",
- )
- session.add(ti)
- session.commit()
- ti.render_templates()
- task: MLEngineCreateModelOperator = ti.task
- assert task.project_id == "project_id"
- assert task.model == "model"
- assert task.impersonation_chain == "impersonation_chain"
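
The test_templating cases in the files deleted above share one recipe: instantiate the operator with Jinja expressions in its templated fields, render, then assert the literal results. Stripped of the create_task_instance_of_operator fixture, the core of that recipe looks roughly like this sketch (FakeOperator is illustrative; Airflow's own rendering walks template_fields in a similar way):

from jinja2 import Template


class FakeOperator:
    template_fields = ("project_id", "model")

    def __init__(self, project_id, model):
        self.project_id = project_id
        self.model = model

    def render_templates(self, context):
        # Render every templated attribute in place.
        for field in self.template_fields:
            setattr(self, field, Template(getattr(self, field)).render(**context))


task = FakeOperator(project_id="{{ 'project_id' }}", model="{{ 'model' }}")
task.render_templates(context={})
assert task.project_id == "project_id"
assert task.model == "model"
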