diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py index 8da25c042c036..34abb250b9f92 100644 --- a/airflow-core/tests/unit/always/test_project_structure.py +++ b/airflow-core/tests/unit/always/test_project_structure.py @@ -136,7 +136,6 @@ def test_providers_modules_should_have_tests(self): "providers/fab/tests/unit/fab/www/test_session.py", "providers/fab/tests/unit/fab/www/test_views.py", "providers/google/tests/unit/google/cloud/fs/test_gcs.py", - "providers/google/tests/unit/google/cloud/links/test_automl.py", "providers/google/tests/unit/google/cloud/links/test_base.py", "providers/google/tests/unit/google/cloud/links/test_bigquery.py", "providers/google/tests/unit/google/cloud/links/test_bigquery_dts.py", @@ -162,6 +161,7 @@ def test_providers_modules_should_have_tests(self): "providers/google/tests/unit/google/cloud/links/test_spanner.py", "providers/google/tests/unit/google/cloud/links/test_stackdriver.py", "providers/google/tests/unit/google/cloud/links/test_workflows.py", + "providers/google/tests/unit/google/cloud/links/test_translate.py", "providers/google/tests/unit/google/cloud/operators/vertex_ai/test_auto_ml.py", "providers/google/tests/unit/google/cloud/operators/vertex_ai/test_batch_prediction_job.py", "providers/google/tests/unit/google/cloud/operators/vertex_ai/test_custom_job.py", diff --git a/providers/google/docs/changelog.rst b/providers/google/docs/changelog.rst index a61e04d64c271..d76bccc144645 100644 --- a/providers/google/docs/changelog.rst +++ b/providers/google/docs/changelog.rst @@ -27,6 +27,30 @@ Changelog --------- +.. warning:: + Deprecated classes, parameters, and features have been removed from the Google provider package. + The following breaking changes were introduced: + +* Operators + + * ``Remove AutoMLTrainModelOperator; use airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator, airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator, airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator, airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator, or airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator instead`` + * ``Remove AutoMLPredictOperator; use airflow.providers.google.cloud.operators.translate.TranslateTextOperator instead`` + * ``Remove AutoMLCreateDatasetOperator; use airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator or airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator instead`` + * ``Remove AutoMLImportDataOperator; use airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator or airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator instead`` + * ``Remove AutoMLTablesListColumnSpecsOperator because of the shutdown of the legacy version of AutoML Tables`` + * ``Remove AutoMLTablesUpdateDatasetOperator; use airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator instead`` + * ``Remove AutoMLGetModelOperator; use airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator instead`` + * ``Remove AutoMLDeleteModelOperator; use airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator or airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator instead`` + * ``Remove AutoMLDeployModelOperator;
use airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator instead`` + * ``Remove AutoMLTablesListTableSpecsOperator because of the shutdown of the legacy version of AutoML Tables`` + * ``Remove AutoMLListDatasetOperator; use airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator or airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator instead`` + * ``Remove AutoMLDeleteDatasetOperator; use airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator or airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator instead`` + * ``Remove MLEngineCreateModelOperator; use an appropriate Vertex AI operator instead`` + +* Hooks + + * ``Remove CloudAutoMLHook; use airflow.providers.google.cloud.hooks.vertex_ai.auto_ml.AutoMLHook or airflow.providers.google.cloud.hooks.translate.TranslateHook instead`` + 18.1.0 ...... diff --git a/providers/google/docs/integration-logos/Cloud-AutoML.png b/providers/google/docs/integration-logos/Cloud-AutoML.png deleted file mode 100644 index b147e074b58fd..0000000000000 Binary files a/providers/google/docs/integration-logos/Cloud-AutoML.png and /dev/null differ diff --git a/providers/google/docs/operators/cloud/automl.rst b/providers/google/docs/operators/cloud/automl.rst deleted file mode 100644 index bd243e1165daf..0000000000000 --- a/providers/google/docs/operators/cloud/automl.rst +++ /dev/null @@ -1,244 +0,0 @@ - .. Licensed to the Apache Software Foundation (ASF) under one - or more contributor license agreements. See the NOTICE file - distributed with this work for additional information - regarding copyright ownership. The ASF licenses this file - to you under the Apache License, Version 2.0 (the - "License"); you may not use this file except in compliance - with the License. You may obtain a copy of the License at - - .. http://www.apache.org/licenses/LICENSE-2.0 - - .. Unless required by applicable law or agreed to in writing, - software distributed under the License is distributed on an - "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - KIND, either express or implied. See the License for the - specific language governing permissions and limitations - under the License. - -Google Cloud AutoML Operators -============================= - -.. warning:: - The AutoML API is deprecated. Planned removal date is September 30, 2025, but some operators might be deleted - earlier, according to the docs and deprecation warnings! - The replacement suggestions can be found in the deprecation warnings or in the doc below. - Please note that AutoML for translation API functionality has been moved to the Advanced Translation service, - the operators can be found at ``airflow.providers.google.cloud.operators.translate`` module. - -The `Google Cloud AutoML `__ -makes the power of machine learning available to you even if you have limited knowledge -of machine learning. You can use AutoML to build on Google's machine learning capabilities -to create your own custom machine learning models that are tailored to your business needs, -and then integrate those models into your applications and web sites. - -Prerequisite Tasks -^^^^^^^^^^^^^^^^^^ - -.. include:: /operators/_partials/prerequisite_tasks.rst - -.. _howto/operator:CloudAutoMLDocuments: .. _howto/operator:AutoMLCreateDatasetOperator: .. _howto/operator:AutoMLImportDataOperator: .. 
_howto/operator:AutoMLTablesUpdateDatasetOperator: - -Creating Datasets -^^^^^^^^^^^^^^^^^ - -To create a Google AutoML dataset you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLCreateDatasetOperator`. -The operator returns dataset id in :ref:`XCom ` under ``dataset_id`` key. - -This operator is deprecated when running for text, video and vision prediction and will be removed after September 30, 2025. -All the functionality of legacy AutoML Natural Language, Vision, Video Intelligence and new features are -available on the Vertex AI platform. Please use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator` -:class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator`. - -After creating a dataset you can use it to import some data using -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLImportDataOperator`. - -To update dataset you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator`. - -.. warning:: - This operator is deprecated when running for text, video and vision prediction and will be removed soon. - All the functionality of legacy AutoML Natural Language, Vision, Video Intelligence and new features are - available on the Vertex AI platform. Please use - :class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator` - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_dataset.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_update_dataset_operator] - :end-before: [END how_to_cloud_vertex_ai_update_dataset_operator] - -.. _howto/operator:AutoMLTablesListTableSpecsOperator: -.. _howto/operator:AutoMLTablesListColumnSpecsOperator: - -Listing Table And Columns Specs -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To list table specs you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator`. - -To list column specs you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTablesListColumnSpecsOperator`. - -AutoML Tables related operators are deprecated. Please use related Vertex AI Tabular operators. - -.. _howto/operator:AutoMLTrainModelOperator: -.. _howto/operator:AutoMLGetModelOperator: -.. _howto/operator:AutoMLDeployModelOperator: -.. _howto/operator:AutoMLDeleteModelOperator: - -Operations On Models -^^^^^^^^^^^^^^^^^^^^ - -To create a Google AutoML model you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLTrainModelOperator`. -The operator will wait for the operation to complete. Additionally the operator -returns the id of model in :ref:`XCom ` under ``model_id`` key. - -.. warning:: - This operator is deprecated when running for text, video and vision prediction and will be removed after September 30, 2025. - All the functionality of legacy AutoML Natural Language, Vision, Video Intelligence and new features are - available on the Vertex AI platform. 
Please use - :class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`, - :class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`, - :class:`~airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`, - :class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator`, - :class:`~airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator`. - -When running Vertex AI Operator for training data, please ensure that your data is correctly stored in Vertex AI -datasets. To create and import data to the dataset please use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator` -and -:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator` - -For the AutoML translation please use the -:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextOperator` -or -:class:`~airflow.providers.google.cloud.operators.translate.TranslateTextBatchOperator`. - -To get existing model one can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLGetModelOperator`. - -This operator deprecated for tables, video intelligence, vision and natural language is deprecated -and will be removed after 31.03.2024. Please use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator` instead. -You can find example on how to use VertexAI operators here: - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_model_service.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_get_model_operator] - :end-before: [END how_to_cloud_vertex_ai_get_model_operator] - -Once a model is created it could be deployed using -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator`. - -This operator deprecated for tables, video intelligence, vision and natural language is deprecated -and will be removed after 31.03.2024. Please use -:class:`airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator` instead. -You can find example on how to use VertexAI operators here: - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_endpoint.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_deploy_model_operator] - :end-before: [END how_to_cloud_vertex_ai_deploy_model_operator] - -If you wish to delete a model you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteModelOperator`. - -This operator deprecated for tables, video intelligence, vision and natural language is deprecated -and will be removed after 31.03.2024. Please use -:class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator` instead. -You can find example on how to use VertexAI operators here: - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_model_service.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_delete_model_operator] - :end-before: [END how_to_cloud_vertex_ai_delete_model_operator] - -.. _howto/operator:AutoMLPredictOperator: - -Making Predictions -^^^^^^^^^^^^^^^^^^ - -To obtain predictions from Google Cloud AutoML model you can use -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLPredictOperator`. 
In the first case -the model must be deployed. - - -For tables, video intelligence, vision and natural language you can use the following operators: - -:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.CreateBatchPredictionJobOperator`, -:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator`, -:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.ListBatchPredictionJobsOperator`, -:class:`airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.DeleteBatchPredictionJobOperator`. -You can find examples on how to use VertexAI operators here: - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_create_batch_prediction_job_operator] - :end-before: [END how_to_cloud_vertex_ai_create_batch_prediction_job_operator] - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_list_batch_prediction_job_operator] - :end-before: [END how_to_cloud_vertex_ai_list_batch_prediction_job_operator] - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_batch_prediction_job.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_delete_batch_prediction_job_operator] - :end-before: [END how_to_cloud_vertex_ai_delete_batch_prediction_job_operator] - -.. _howto/operator:AutoMLListDatasetOperator: -.. _howto/operator:AutoMLDeleteDatasetOperator: - -Listing And Deleting Datasets -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -You can get a list of AutoML datasets using -:class:`~airflow.providers.google.cloud.operators.automl.AutoMLListDatasetOperator`. The operator returns list -of datasets ids in :ref:`XCom ` under ``dataset_id_list`` key. - -This operator deprecated for tables, video intelligence, vision and natural language is deprecated -and will be removed after 31.03.2024. Please use -:class:`~airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator`, -:class:`~airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator` -instead. -You can find example on how to use VertexAI operators here: - -.. exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_dataset.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_list_dataset_operator] - :end-before: [END how_to_cloud_vertex_ai_list_dataset_operator] - -To delete a dataset you can use :class:`~airflow.providers.google.cloud.operators.automl.AutoMLDeleteDatasetOperator`. -The delete operator allows also to pass list or coma separated string of datasets ids to be deleted. - -This operator deprecated for tables, video intelligence, vision and natural language is deprecated -and will be removed after 31.03.2024. Please use -:class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator` instead. -You can find example on how to use VertexAI operators here: - -.. 
exampleinclude:: /../../google/tests/system/google/cloud/vertex_ai/example_vertex_ai_dataset.py - :language: python - :dedent: 4 - :start-after: [START how_to_cloud_vertex_ai_delete_dataset_operator] - :end-before: [END how_to_cloud_vertex_ai_delete_dataset_operator] - -Reference -^^^^^^^^^ - -For further information, look at: - -* `Client Library Documentation `__ -* `Product Documentation `__ diff --git a/providers/google/docs/operators/cloud/mlengine.rst b/providers/google/docs/operators/cloud/mlengine.rst index 1fc63b9528d76..407954011efe2 100644 --- a/providers/google/docs/operators/cloud/mlengine.rst +++ b/providers/google/docs/operators/cloud/mlengine.rst @@ -45,22 +45,8 @@ the Vertex AI platform. Creating a model ^^^^^^^^^^^^^^^^ -A model is a container that can hold multiple model versions. A new model can be created through the -:class:`~airflow.providers.google.cloud.operators.mlengine.MLEngineCreateModelOperator`. -The ``model`` field should be defined with a dictionary containing the information about the model. -``name`` is a required field in this dictionary. - -.. warning:: - This operator is deprecated. The model is created as a result of running Vertex AI operators that create training jobs - of any types. For example, you can use - :class:`~airflow.providers.google.cloud.operators.vertex_ai.custom_job.CreateCustomPythonPackageTrainingJobOperator`. - The result of running this operator will be ready-to-use model saved in Model Registry. - -.. exampleinclude:: /../../google/tests/system/google/cloud/ml_engine/example_mlengine.py - :language: python - :dedent: 4 - :start-after: [START howto_operator_create_custom_python_training_job_v1] - :end-before: [END howto_operator_create_custom_python_training_job_v1] +This function is deprecated. All the functionality of legacy MLEngine and new features are available on +the Vertex AI platform. Getting a model ^^^^^^^^^^^^^^^ @@ -72,18 +58,13 @@ Creating model versions This function is deprecated. All the functionality of legacy MLEngine and new features are available on the Vertex AI platform. -For model versioning please check: -`Model versioning with Vertex AI -`__ Managing model versions ^^^^^^^^^^^^^^^^^^^^^^^ This function is deprecated. All the functionality of legacy MLEngine and new features are available on the Vertex AI platform. -For model versioning please check: -`Model versioning with Vertex AI -`__ + Making predictions ^^^^^^^^^^^^^^^^^^ @@ -103,14 +84,3 @@ Evaluating a model This function is deprecated. All the functionality of legacy MLEngine and new features are available on the Vertex AI platform. 
-To create and view Model Evaluation, please check the documentation: -`Evaluate models using Vertex AI -`__ - -Reference -^^^^^^^^^ - -For further information, look at: - -* `Client Library Documentation `__ -* `Product Documentation `__ diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 09f5ee4ff87af..5c98aacfa6529 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -127,12 +127,6 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-google/operators/ads.rst tags: [gmp] - - integration-name: Google AutoML - external-doc-url: https://cloud.google.com/automl/ - how-to-guide: - - /docs/apache-airflow-providers-google/operators/cloud/automl.rst - logo: /docs/integration-logos/Cloud-AutoML.png - tags: [gcp] - integration-name: Google BigQuery Data Transfer Service external-doc-url: https://cloud.google.com/bigquery/transfer/ logo: /docs/integration-logos/BigQuery.png @@ -473,9 +467,6 @@ operators: - integration-name: Google Cloud AlloyDB python-modules: - airflow.providers.google.cloud.operators.alloy_db - - integration-name: Google AutoML - python-modules: - - airflow.providers.google.cloud.operators.automl - integration-name: Google BigQuery python-modules: - airflow.providers.google.cloud.operators.bigquery @@ -539,9 +530,6 @@ operators: - integration-name: Google Kubernetes Engine python-modules: - airflow.providers.google.cloud.operators.kubernetes_engine - - integration-name: Google Machine Learning Engine - python-modules: - - airflow.providers.google.cloud.operators.mlengine - integration-name: Google Cloud Natural Language python-modules: - airflow.providers.google.cloud.operators.natural_language @@ -729,9 +717,6 @@ hooks: - integration-name: Google Ads python-modules: - airflow.providers.google.ads.hooks.ads - - integration-name: Google AutoML - python-modules: - - airflow.providers.google.cloud.hooks.automl - integration-name: Google BigQuery python-modules: - airflow.providers.google.cloud.hooks.bigquery diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/automl.py b/providers/google/src/airflow/providers/google/cloud/hooks/automl.py deleted file mode 100644 index a60262c21b900..0000000000000 --- a/providers/google/src/airflow/providers/google/cloud/hooks/automl.py +++ /dev/null @@ -1,673 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -This module contains a Google AutoML hook. - -.. 
spelling:word-list:: - - PredictResponse -""" - -from __future__ import annotations - -from collections.abc import Sequence -from functools import cached_property -from typing import TYPE_CHECKING - -from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.cloud.automl_v1beta1 import ( - AutoMlClient, - BatchPredictInputConfig, - BatchPredictOutputConfig, - Dataset, - ExamplePayload, - ImageObjectDetectionModelDeploymentMetadata, - InputConfig, - Model, - PredictionServiceClient, - PredictResponse, -) - -from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.google.common.consts import CLIENT_INFO -from airflow.providers.google.common.deprecated import deprecated -from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook -from airflow.providers.google.common.hooks.operation_helpers import OperationHelper - -if TYPE_CHECKING: - from google.api_core.operation import Operation - from google.api_core.retry import Retry - from google.cloud.automl_v1beta1.services.auto_ml.pagers import ( - ListColumnSpecsPager, - ListDatasetsPager, - ListTableSpecsPager, - ) - from google.protobuf.field_mask_pb2 import FieldMask - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.hooks.vertex_ai.auto_ml.AutoMLHook, " - "airflow.providers.google.cloud.hooks.translate.TranslateHook", - category=AirflowProviderDeprecationWarning, -) -class CloudAutoMLHook(GoogleBaseHook, OperationHelper): - """ - Google Cloud AutoML hook. - - All the methods in the hook where project_id is used must be called with - keyword arguments rather than positional. - """ - - def __init__( - self, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__( - gcp_conn_id=gcp_conn_id, - impersonation_chain=impersonation_chain, - **kwargs, - ) - self._client: AutoMlClient | None = None - - @staticmethod - def extract_object_id(obj: dict) -> str: - """Return unique id of the object.""" - return obj["name"].rpartition("/")[-1] - - def get_conn(self) -> AutoMlClient: - """ - Retrieve connection to AutoML. - - :return: Google Cloud AutoML client object. - """ - if self._client is None: - self._client = AutoMlClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) - return self._client - - @cached_property - def prediction_client(self) -> PredictionServiceClient: - """ - Creates PredictionServiceClient. - - :return: Google Cloud AutoML PredictionServiceClient client object. - """ - return PredictionServiceClient(credentials=self.get_credentials(), client_info=CLIENT_INFO) - - @GoogleBaseHook.fallback_to_default_project_id - def create_model( - self, - model: dict | Model, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - retry: Retry | _MethodDefault = DEFAULT, - ) -> Operation: - """ - Create a model_id and returns a Model in the `response` field when it completes. - - When you create a model, several model evaluations are created for it: - a global evaluation, and one evaluation for each annotation spec. - - :param model: The model_id to create. If a dict is provided, it must be of the same form - as the protobuf message `google.cloud.automl_v1beta1.types.Model` - :param project_id: ID of the Google Cloud project where model will be created if None then - default project_id is used. 
- :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests - will not be retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. - Note that if `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance - """ - client = self.get_conn() - parent = f"projects/{project_id}/locations/{location}" - return client.create_model( - request={"parent": parent, "model": model}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - - @GoogleBaseHook.fallback_to_default_project_id - def batch_predict( - self, - model_id: str, - input_config: dict | BatchPredictInputConfig, - output_config: dict | BatchPredictOutputConfig, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - params: dict[str, str] | None = None, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: - """ - Perform a batch prediction and returns a long-running operation object. - - Unlike the online `Predict`, batch prediction result won't be immediately - available in the response. Instead, a long-running operation object is returned. - - :param model_id: Name of the model_id requested to serve the batch prediction. - :param input_config: Required. The input configuration for batch prediction. - If a dict is provided, it must be of the same form as the protobuf message - `google.cloud.automl_v1beta1.types.BatchPredictInputConfig` - :param output_config: Required. The Configuration specifying where output predictions should be - written. If a dict is provided, it must be of the same form as the protobuf message - `google.cloud.automl_v1beta1.types.BatchPredictOutputConfig` - :param params: Additional domain-specific parameters for the predictions, any string must be up to - 25000 characters long. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance - """ - client = self.prediction_client - name = f"projects/{project_id}/locations/{location}/models/{model_id}" - result = client.batch_predict( - request={ - "name": name, - "input_config": input_config, - "output_config": output_config, - "params": params, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def predict( - self, - model_id: str, - payload: dict | ExamplePayload, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - params: dict[str, str] | None = None, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> PredictResponse: - """ - Perform an online prediction and returns the prediction result in the response. - - :param model_id: Name of the model_id requested to serve the prediction. - :param payload: Required. 
Payload to perform a prediction on. The payload must match the problem type - that the model_id was trained to solve. If a dict is provided, it must be of - the same form as the protobuf message `google.cloud.automl_v1beta1.types.ExamplePayload` - :param params: Additional domain-specific parameters, any string must be up to 25000 characters long. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types.PredictResponse` instance - """ - client = self.prediction_client - name = f"projects/{project_id}/locations/{location}/models/{model_id}" - result = client.predict( - request={"name": name, "payload": payload, "params": params}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def create_dataset( - self, - dataset: dict | Dataset, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Dataset: - """ - Create a dataset. - - :param dataset: The dataset to create. If a dict is provided, it must be of the - same form as the protobuf message Dataset. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types.Dataset` instance. - """ - client = self.get_conn() - parent = f"projects/{project_id}/locations/{location}" - result = client.create_dataset( - request={"parent": parent, "dataset": dataset}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def import_data( - self, - dataset_id: str, - location: str, - input_config: dict | InputConfig, - project_id: str = PROVIDE_PROJECT_ID, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: - """ - Import data into a dataset. For Tables this method can only be called on an empty Dataset. - - :param dataset_id: Name of the AutoML dataset. - :param input_config: The desired input location and its domain specific semantics, if any. - If a dict is provided, it must be of the same form as the protobuf message InputConfig. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance - """ - client = self.get_conn() - name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" - result = client.import_data( - request={"name": name, "input_config": input_config}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def list_column_specs( - self, - dataset_id: str, - table_spec_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - field_mask: dict | FieldMask | None = None, - filter_: str | None = None, - page_size: int | None = None, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> ListColumnSpecsPager: - """ - List column specs in a table spec. - - :param dataset_id: Name of the AutoML dataset. - :param table_spec_id: table_spec_id for path builder. - :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same - form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask` - :param filter_: Filter expression, see go/filtering. - :param page_size: The maximum number of resources contained in the - underlying API response. If page streaming is performed per - resource, this parameter does not affect the return value. If page - streaming is performed per-page, this determines the maximum number - of resources in a page. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types.ColumnSpec` instance. - """ - client = self.get_conn() - parent = client.table_spec_path( - project=project_id, - location=location, - dataset=dataset_id, - table_spec=table_spec_id, - ) - result = client.list_column_specs( - request={"parent": parent, "field_mask": field_mask, "filter": filter_, "page_size": page_size}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def get_model( - self, - model_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Model: - """ - Get a AutoML model. - - :param model_id: Name of the model. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types.Model` instance. 
- """ - client = self.get_conn() - name = f"projects/{project_id}/locations/{location}/models/{model_id}" - result = client.get_model( - request={"name": name}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def delete_model( - self, - model_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: - """ - Delete a AutoML model. - - :param model_id: Name of the model. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. - """ - client = self.get_conn() - name = f"projects/{project_id}/locations/{location}/models/{model_id}" - result = client.delete_model( - request={"name": name}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - def update_dataset( - self, - dataset: dict | Dataset, - update_mask: dict | FieldMask | None = None, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Dataset: - """ - Update a dataset. - - :param dataset: The dataset which replaces the resource on the server. - If a dict is provided, it must be of the same form as the protobuf message Dataset. - :param update_mask: The update mask applies to the resource. If a dict is provided, it must - be of the same form as the protobuf message FieldMask. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types.Dataset` instance.. - """ - client = self.get_conn() - result = client.update_dataset( - request={"dataset": dataset, "update_mask": update_mask}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def deploy_model( - self, - model_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - image_detection_metadata: ImageObjectDetectionModelDeploymentMetadata | dict | None = None, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: - """ - Deploys a model. - - If a model is already deployed, deploying it with the same parameters - has no effect. Deploying with different parameters (as e.g. changing node_number) will - reset the deployment state without pausing the model_id's availability. - - Only applicable for Text Classification, Image Object Detection and Tables; all other - domains manage deployment automatically. - - :param model_id: Name of the model requested to serve the prediction. 
- :param image_detection_metadata: Model deployment metadata specific to Image Object Detection. - If a dict is provided, it must be of the same form as the protobuf message - ImageObjectDetectionModelDeploymentMetadata - :param project_id: ID of the Google Cloud project where model will be created if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance. - """ - client = self.get_conn() - name = f"projects/{project_id}/locations/{location}/models/{model_id}" - result = client.deploy_model( - request={ - "name": name, - "image_object_detection_model_deployment_metadata": image_detection_metadata, - }, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - def list_table_specs( - self, - dataset_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - filter_: str | None = None, - page_size: int | None = None, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> ListTableSpecsPager: - """ - List table specs in a dataset_id. - - :param dataset_id: Name of the dataset. - :param filter_: Filter expression, see go/filtering. - :param page_size: The maximum number of resources contained in the - underlying API response. If page streaming is performed per - resource, this parameter does not affect the return value. If page - streaming is performed per-page, this determines the maximum number - of resources in a page. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: A `google.gax.PageIterator` instance. By default, this - is an iterable of `google.cloud.automl_v1beta1.types.TableSpec` instances. - This object can also be configured to iterate over the pages - of the response through the `options` parameter. - """ - client = self.get_conn() - parent = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" - result = client.list_table_specs( - request={"parent": parent, "filter": filter_, "page_size": page_size}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def list_datasets( - self, - location: str, - project_id: str, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> ListDatasetsPager: - """ - List datasets in a project. - - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. 
If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: A `google.gax.PageIterator` instance. By default, this - is an iterable of `google.cloud.automl_v1beta1.types.Dataset` instances. - This object can also be configured to iterate over the pages - of the response through the `options` parameter. - """ - client = self.get_conn() - parent = f"projects/{project_id}/locations/{location}" - result = client.list_datasets( - request={"parent": parent}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def delete_dataset( - self, - dataset_id: str, - location: str, - project_id: str, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Operation: - """ - Delete a dataset and all of its contents. - - :param dataset_id: ID of dataset to be deleted. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types._OperationFuture` instance - """ - client = self.get_conn() - name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" - result = client.delete_dataset( - request={"name": name}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) - return result - - @GoogleBaseHook.fallback_to_default_project_id - def get_dataset( - self, - dataset_id: str, - location: str, - project_id: str, - retry: Retry | _MethodDefault = DEFAULT, - timeout: float | None = None, - metadata: Sequence[tuple[str, str]] = (), - ) -> Dataset: - """ - Retrieve the dataset for the given dataset_id. - - :param dataset_id: ID of dataset to be retrieved. - :param location: The location of the project. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - - :return: `google.cloud.automl_v1beta1.types.dataset.Dataset` instance. 
- """ - client = self.get_conn() - name = f"projects/{project_id}/locations/{location}/datasets/{dataset_id}" - return client.get_dataset( - request={"name": name}, - retry=retry, - timeout=timeout, - metadata=metadata, - ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/automl.py b/providers/google/src/airflow/providers/google/cloud/operators/automl.py deleted file mode 100644 index 905b0acf2c074..0000000000000 --- a/providers/google/src/airflow/providers/google/cloud/operators/automl.py +++ /dev/null @@ -1,1364 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""This module contains Google AutoML operators.""" - -from __future__ import annotations - -import ast -from collections.abc import Sequence -from functools import cached_property -from typing import TYPE_CHECKING, cast - -from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.cloud.automl_v1beta1 import ( - ColumnSpec, - Dataset, - Model, - PredictResponse, - TableSpec, -) - -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning -from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook -from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook -from airflow.providers.google.cloud.links.translate import ( - TranslationDatasetListLink, - TranslationLegacyDatasetLink, - TranslationLegacyModelLink, - TranslationLegacyModelPredictLink, - TranslationLegacyModelTrainLink, -) -from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.common.deprecated import deprecated -from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID - -if TYPE_CHECKING: - from google.api_core.retry import Retry - - from airflow.providers.common.compat.sdk import Context - -MetaData = Sequence[tuple[str, str]] - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator, " - "airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator, " - "airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator, " - "airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator, " - "airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLTrainModelOperator(GoogleCloudBaseOperator): - """ - Creates Google Cloud AutoML model. - - .. warning:: - AutoMLTrainModelOperator for tables, video intelligence, vision and natural language has been deprecated - and no longer available. 
Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLTabularTrainingJobOperator`, - :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLVideoTrainingJobOperator`, - :class:`airflow.providers.google.cloud.operators.vertex_ai.auto_ml.CreateAutoMLImageTrainingJobOperator`, - :class:`airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator`, - :class:`airflow.providers.google.cloud.operators.translate.TranslateCreateModelOperator`. - instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLTrainModelOperator` - - :param model: Model definition. - :param project_id: ID of the Google Cloud project where model will be created if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
- """ - - template_fields: Sequence[str] = ( - "model", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = ( - TranslationLegacyModelTrainLink(), - TranslationLegacyModelLink(), - ) - - def __init__( - self, - *, - model: dict, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.model = model - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.log.info("Creating model %s...", self.model["display_name"]) - operation = hook.create_model( - model=self.model, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - project_id = self.project_id or hook.project_id - if project_id: - TranslationLegacyModelTrainLink.persist( - context=context, - dataset_id=self.model["dataset_id"], - project_id=project_id, - location=self.location, - ) - operation_result = hook.wait_for_operation(timeout=self.timeout, operation=operation) - result = Model.to_dict(operation_result) - model_id = hook.extract_object_id(result) - self.log.info("Model is created, model_id: %s", model_id) - - context["task_instance"].xcom_push(key="model_id", value=model_id) - if project_id: - TranslationLegacyModelLink.persist( - context=context, - dataset_id=self.model["dataset_id"] or "-", - model_id=model_id, - project_id=project_id, - location=self.location, - ) - return result - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.translate.TranslateTextOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLPredictOperator(GoogleCloudBaseOperator): - """ - Runs prediction operation on Google Cloud AutoML. - - .. warning:: - AutoMLPredictOperator for text, image, and video prediction has been deprecated. - Please use endpoint_id param instead of model_id param. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLPredictOperator` - - :param model_id: Name of the model requested to serve the batch prediction. - :param endpoint_id: Name of the endpoint used for the prediction. - :param payload: Name of the model used for the prediction. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param operation_params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
- :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "model_id", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyModelPredictLink(),) - - def __init__( - self, - *, - model_id: str | None = None, - endpoint_id: str | None = None, - location: str, - payload: dict, - operation_params: dict[str, str] | None = None, - instances: list[str] | None = None, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.model_id = model_id - self.endpoint_id = endpoint_id - self.operation_params = operation_params - self.instances = instances - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.payload = payload - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - @cached_property - def hook(self) -> CloudAutoMLHook | PredictionServiceHook: - if self.model_id: - return CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - # endpoint_id defined - return PredictionServiceHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - @cached_property - def model(self) -> Model | None: - if self.model_id: - hook = cast("CloudAutoMLHook", self.hook) - return hook.get_model( - model_id=self.model_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - return None - - def execute(self, context: Context): - if self.model_id is None and self.endpoint_id is None: - raise AirflowException("You must specify model_id or endpoint_id!") - - hook = self.hook - if self.model_id: - result = hook.predict( - model_id=self.model_id, - payload=self.payload, - location=self.location, - project_id=self.project_id, - params=self.operation_params, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - else: # self.endpoint_id is defined - result = hook.predict( - endpoint_id=self.endpoint_id, - instances=self.instances, - payload=self.payload, - location=self.location, - project_id=self.project_id, - parameters=self.operation_params, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - - project_id = self.project_id or hook.project_id - dataset_id: str | None = self.model.dataset_id if self.model else None - if project_id and self.model_id and dataset_id: - TranslationLegacyModelPredictLink.persist( - context=context, - model_id=self.model_id, - dataset_id=dataset_id, - project_id=project_id, - location=self.location, - ) - return PredictResponse.to_dict(result) - - -@deprecated( - planned_removal_date="September 30, 2025", - 
use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator, " - "airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLCreateDatasetOperator(GoogleCloudBaseOperator): - """ - Creates a Google Cloud AutoML dataset. - - AutoMLCreateDatasetOperator for tables, video intelligence, vision and natural language has been - deprecated and no longer available. Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.CreateDatasetOperator`, - :class:`airflow.providers.google.cloud.operators.translate.TranslateCreateDatasetOperator` instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLCreateDatasetOperator` - - :param dataset: The dataset to create. If a dict is provided, it must be of the - same form as the protobuf message Dataset. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "dataset", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyDatasetLink(),) - - def __init__( - self, - *, - dataset: dict, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.dataset = dataset - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.log.info("Creating dataset %s...", self.dataset) - result = hook.create_dataset( - dataset=self.dataset, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - result = Dataset.to_dict(result) - dataset_id = hook.extract_object_id(result) - self.log.info("Creating completed. 
Dataset id: %s", dataset_id) - - context["task_instance"].xcom_push(key="dataset_id", value=dataset_id) - project_id = self.project_id or hook.project_id - if project_id: - TranslationLegacyDatasetLink.persist( - context=context, - dataset_id=dataset_id, - project_id=project_id, - location=self.location, - ) - return result - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator, " - "airflow.providers.google.cloud.operators.translate.TranslateImportDataOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLImportDataOperator(GoogleCloudBaseOperator): - """ - Imports data to a Google Cloud AutoML dataset. - - .. warning:: - AutoMLImportDataOperator for tables, video intelligence, vision and natural language has been deprecated - and no longer available. Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.ImportDataOperator` instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLImportDataOperator` - - :param dataset_id: ID of dataset to be updated. - :param input_config: The desired input location and its domain specific semantics, if any. - If a dict is provided, it must be of the same form as the protobuf message InputConfig. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
- """ - - template_fields: Sequence[str] = ( - "dataset_id", - "input_config", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyDatasetLink(),) - - def __init__( - self, - *, - dataset_id: str, - location: str, - input_config: dict, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.dataset_id = dataset_id - self.input_config = input_config - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - hook.get_dataset( - dataset_id=self.dataset_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - self.log.info("Importing data to dataset...") - operation = hook.import_data( - dataset_id=self.dataset_id, - input_config=self.input_config, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Import is completed") - project_id = self.project_id or hook.project_id - if project_id: - TranslationLegacyDatasetLink.persist( - context=context, - dataset_id=self.dataset_id, - project_id=project_id, - location=self.location, - ) - - -@deprecated( - planned_removal_date="September 30, 2025", - category=AirflowProviderDeprecationWarning, - reason="Shutdown of legacy version of AutoML Tables on March 31, 2024.", -) -class AutoMLTablesListColumnSpecsOperator(GoogleCloudBaseOperator): - """ - Lists column specs in a table. - - .. warning:: - Operator AutoMLTablesListColumnSpecsOperator has been deprecated due to shutdown of - a legacy version of AutoML Tables on March 31, 2024. For additional information - see: https://cloud.google.com/automl-tables/docs/deprecations. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLTablesListColumnSpecsOperator` - - :param dataset_id: Name of the dataset. - :param table_spec_id: table_spec_id for path builder. - :param field_mask: Mask specifying which fields to read. If a dict is provided, it must be of the same - form as the protobuf message `google.cloud.automl_v1beta1.types.FieldMask` - :param filter_: Filter expression, see go/filtering. - :param page_size: The maximum number of resources contained in the - underlying API response. If page streaming is performed per - resource, this parameter does not affect the return value. If page - streaming is performed per page, this determines the maximum number - of resources in a page. - :param project_id: ID of the Google Cloud project where dataset is located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. 
Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "dataset_id", - "table_spec_id", - "field_mask", - "filter_", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyDatasetLink(),) - - def __init__( - self, - *, - dataset_id: str, - table_spec_id: str, - location: str, - field_mask: dict | None = None, - filter_: str | None = None, - page_size: int | None = None, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.dataset_id = dataset_id - self.table_spec_id = table_spec_id - self.field_mask = field_mask - self.filter_ = filter_ - self.page_size = page_size - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.log.info("Requesting column specs.") - page_iterator = hook.list_column_specs( - dataset_id=self.dataset_id, - table_spec_id=self.table_spec_id, - field_mask=self.field_mask, - filter_=self.filter_, - page_size=self.page_size, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - result = [ColumnSpec.to_dict(spec) for spec in page_iterator] - self.log.info("Columns specs obtained.") - project_id = self.project_id or hook.project_id - if project_id: - TranslationLegacyDatasetLink.persist( - context=context, - dataset_id=self.dataset_id, - project_id=project_id, - location=self.location, - ) - return result - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator", - category=AirflowProviderDeprecationWarning, - reason="Shutdown of legacy version of AutoML Tables on March 31, 2024.", -) -class AutoMLTablesUpdateDatasetOperator(GoogleCloudBaseOperator): - """ - Updates a dataset. - - .. warning:: - Operator AutoMLTablesUpdateDatasetOperator has been deprecated due to shutdown of - a legacy version of AutoML Tables on March 31, 2024. For additional information - see: https://cloud.google.com/automl-tables/docs/deprecations. - Please use :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.UpdateDatasetOperator` - instead. - - .. 
seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLTablesUpdateDatasetOperator` - - :param dataset: The dataset which replaces the resource on the server. - If a dict is provided, it must be of the same form as the protobuf message Dataset. - :param update_mask: The update mask applies to the resource. If a dict is provided, it must - be of the same form as the protobuf message FieldMask. - :param location: The location of the project. - :param params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "dataset", - "update_mask", - "location", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyDatasetLink(),) - - def __init__( - self, - *, - dataset: dict, - location: str, - update_mask: dict | None = None, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.dataset = dataset - self.update_mask = update_mask - self.location = location - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.log.info("Updating AutoML dataset %s.", self.dataset["name"]) - result = hook.update_dataset( - dataset=self.dataset, - update_mask=self.update_mask, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - self.log.info("Dataset updated.") - project_id = hook.project_id - if project_id: - TranslationLegacyDatasetLink.persist( - context=context, - dataset_id=hook.extract_object_id(self.dataset), - project_id=project_id, - location=self.location, - ) - return Dataset.to_dict(result) - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLGetModelOperator(GoogleCloudBaseOperator): - """ - Get Google Cloud AutoML model. - - .. warning:: - AutoMLGetModelOperator for tables, video intelligence, vision and natural language has been deprecated - and no longer available. 
Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.GetModelOperator` instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLGetModelOperator` - - :param model_id: Name of the model requested to serve the prediction. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "model_id", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyModelLink(),) - - def __init__( - self, - *, - model_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.model_id = model_id - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - result = hook.get_model( - model_id=self.model_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - model = Model.to_dict(result) - project_id = self.project_id or hook.project_id - if project_id: - TranslationLegacyModelLink.persist( - context=context, - dataset_id=model["dataset_id"], - model_id=self.model_id, - project_id=project_id, - location=self.location, - ) - return model - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator, " - "airflow.providers.google.cloud.operators.translate.TranslateDeleteModelOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLDeleteModelOperator(GoogleCloudBaseOperator): - """ - Delete Google Cloud AutoML model. - - .. warning:: - AutoMLDeleteModelOperator for tables, video intelligence, vision and natural language has been deprecated - and no longer available. 
Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.model_service.DeleteModelOperator` instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLDeleteModelOperator` - - :param model_id: Name of the model requested to serve the prediction. - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "model_id", - "location", - "project_id", - "impersonation_chain", - ) - - def __init__( - self, - *, - model_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.model_id = model_id - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - hook.get_model( - model_id=self.model_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - operation = hook.delete_model( - model_id=self.model_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Deletion is completed") - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLDeployModelOperator(GoogleCloudBaseOperator): - """ - Deploys a model; if a model is already deployed, deploying it with the same parameters has no effect. - - Deploying with different parameters (as e.g. changing node_number) will - reset the deployment state without pausing the model_id's availability. 
- - Only applicable for Text Classification, Image Object Detection and Tables; all other - domains manage deployment automatically. - - .. warning:: - Operator AutoMLDeployModelOperator has been deprecated due to shutdown of a legacy version - of AutoML Natural Language, Vision, Video Intelligence on March 31, 2024. - For additional information see: https://cloud.google.com/vision/automl/docs/deprecations . - Please use :class:`airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.DeployModelOperator` - instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLDeployModelOperator` - - :param model_id: Name of the model to be deployed. - :param image_detection_metadata: Model deployment metadata specific to Image Object Detection. - If a dict is provided, it must be of the same form as the protobuf message - ImageObjectDetectionModelDeploymentMetadata - :param project_id: ID of the Google Cloud project where model is located if None then - default project_id is used. - :param location: The location of the project. - :param params: Additional domain-specific parameters for the predictions. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
- """ - - template_fields: Sequence[str] = ( - "model_id", - "location", - "project_id", - "impersonation_chain", - ) - - def __init__( - self, - *, - model_id: str, - location: str, - project_id: str = PROVIDE_PROJECT_ID, - image_detection_metadata: dict | None = None, - metadata: Sequence[tuple[str, str]] = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.model_id = model_id - self.image_detection_metadata = image_detection_metadata - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.log.info("Deploying model_id %s", self.model_id) - operation = hook.deploy_model( - model_id=self.model_id, - location=self.location, - project_id=self.project_id, - image_detection_metadata=self.image_detection_metadata, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - hook.wait_for_operation(timeout=self.timeout, operation=operation) - self.log.info("Model was deployed successfully.") - - -@deprecated( - planned_removal_date="September 30, 2025", - category=AirflowProviderDeprecationWarning, - reason="Shutdown of legacy version of AutoML Tables on March 31, 2024.", -) -class AutoMLTablesListTableSpecsOperator(GoogleCloudBaseOperator): - """ - Lists table specs in a dataset. - - .. warning:: - Operator AutoMLTablesListTableSpecsOperator has been deprecated due to shutdown of - a legacy version of AutoML Tables on March 31, 2024. For additional information - see: https://cloud.google.com/automl-tables/docs/deprecations. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLTablesListTableSpecsOperator` - - :param dataset_id: Name of the dataset. - :param filter_: Filter expression, see go/filtering. - :param page_size: The maximum number of resources contained in the - underlying API response. If page streaming is performed per - resource, this parameter does not affect the return value. If page - streaming is performed per-page, this determines the maximum number - of resources in a page. - :param project_id: ID of the Google Cloud project if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. 
- If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "dataset_id", - "filter_", - "location", - "project_id", - "impersonation_chain", - ) - operator_extra_links = (TranslationLegacyDatasetLink(),) - - def __init__( - self, - *, - dataset_id: str, - location: str, - page_size: int | None = None, - filter_: str | None = None, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.dataset_id = dataset_id - self.filter_ = filter_ - self.page_size = page_size - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - self.log.info("Requesting table specs for %s.", self.dataset_id) - page_iterator = hook.list_table_specs( - dataset_id=self.dataset_id, - filter_=self.filter_, - page_size=self.page_size, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - result = [TableSpec.to_dict(spec) for spec in page_iterator] - self.log.info(result) - self.log.info("Table specs obtained.") - project_id = self.project_id or hook.project_id - if project_id: - TranslationLegacyDatasetLink.persist( - context=context, - dataset_id=self.dataset_id, - project_id=project_id, - location=self.location, - ) - return result - - -@deprecated( - planned_removal_date="September 30, 2025", - use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator, " - "airflow.providers.google.cloud.operators.translate.TranslateDatasetsListOperator", - category=AirflowProviderDeprecationWarning, -) -class AutoMLListDatasetOperator(GoogleCloudBaseOperator): - """ - Lists AutoML Datasets in project. - - .. warning:: - AutoMLListDatasetOperator for tables, video intelligence, vision and natural language has been deprecated - and no longer available. Please use - :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.ListDatasetsOperator` instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:AutoMLListDatasetOperator` - - :param project_id: ID of the Google Cloud project where datasets are located if None then - default project_id is used. - :param location: The location of the project. - :param retry: A retry object used to retry requests. If `None` is specified, requests will not be - retried. - :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if - `retry` is specified, the timeout applies to each individual attempt. - :param metadata: Additional metadata that is provided to the method. - :param gcp_conn_id: The connection ID to use to connect to Google Cloud. 
-    :param impersonation_chain: Optional service account to impersonate using short-term
-        credentials, or chained list of accounts required to get the access_token
-        of the last account in the list, which will be impersonated in the request.
-        If set as a string, the account must grant the originating account
-        the Service Account Token Creator IAM role.
-        If set as a sequence, the identities from the list must grant
-        Service Account Token Creator IAM role to the directly preceding identity, with first
-        account from the list granting this role to the originating account (templated).
-    """
-
-    template_fields: Sequence[str] = (
-        "location",
-        "project_id",
-        "impersonation_chain",
-    )
-    operator_extra_links = (TranslationDatasetListLink(),)
-
-    def __init__(
-        self,
-        *,
-        location: str,
-        project_id: str = PROVIDE_PROJECT_ID,
-        metadata: MetaData = (),
-        timeout: float | None = None,
-        retry: Retry | _MethodDefault = DEFAULT,
-        gcp_conn_id: str = "google_cloud_default",
-        impersonation_chain: str | Sequence[str] | None = None,
-        **kwargs,
-    ) -> None:
-        super().__init__(**kwargs)
-        self.location = location
-        self.project_id = project_id
-        self.metadata = metadata
-        self.timeout = timeout
-        self.retry = retry
-        self.gcp_conn_id = gcp_conn_id
-        self.impersonation_chain = impersonation_chain
-
-    def execute(self, context: Context):
-        hook = CloudAutoMLHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
-        self.log.info("Requesting datasets")
-        page_iterator = hook.list_datasets(
-            location=self.location,
-            project_id=self.project_id,
-            retry=self.retry,
-            timeout=self.timeout,
-            metadata=self.metadata,
-        )
-        result = []
-        for dataset in page_iterator:
-            result.append(Dataset.to_dict(dataset))
-        self.log.info("Datasets obtained.")
-
-        context["task_instance"].xcom_push(
-            key="dataset_id_list",
-            value=[hook.extract_object_id(d) for d in result],
-        )
-        project_id = self.project_id or hook.project_id
-        if project_id:
-            TranslationDatasetListLink.persist(context=context, project_id=project_id)
-        return result
-
-
-@deprecated(
-    planned_removal_date="September 30, 2025",
-    use_instead="airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator, "
-    "airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator",
-    category=AirflowProviderDeprecationWarning,
-)
-class AutoMLDeleteDatasetOperator(GoogleCloudBaseOperator):
-    """
-    Deletes a dataset and all of its contents.
-
-    AutoMLDeleteDatasetOperator for tables, video intelligence, vision and natural language has been
-    deprecated and is no longer available. Please use
-    :class:`airflow.providers.google.cloud.operators.vertex_ai.dataset.DeleteDatasetOperator` or
-    :class:`airflow.providers.google.cloud.operators.translate.TranslateDeleteDatasetOperator` instead.
-
-    .. seealso::
-        For more information on how to use this operator, take a look at the guide:
-        :ref:`howto/operator:AutoMLDeleteDatasetOperator`
-
-    :param dataset_id: ID of the dataset to delete. Accepts a single ID, a list of IDs,
-        or a comma-separated string of IDs.
-    :param project_id: ID of the Google Cloud project where the dataset is located. If None,
-        the default project_id is used.
-    :param location: The location of the project.
-    :param retry: A retry object used to retry requests. If `None` is specified, requests will not be
-        retried.
-    :param timeout: The amount of time, in seconds, to wait for the request to complete. Note that if
-        `retry` is specified, the timeout applies to each individual attempt.
-    :param metadata: Additional metadata that is provided to the method.
- :param gcp_conn_id: The connection ID to use to connect to Google Cloud. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). - """ - - template_fields: Sequence[str] = ( - "dataset_id", - "location", - "project_id", - "impersonation_chain", - ) - - def __init__( - self, - *, - dataset_id: str | list[str], - location: str, - project_id: str = PROVIDE_PROJECT_ID, - metadata: MetaData = (), - timeout: float | None = None, - retry: Retry | _MethodDefault = DEFAULT, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - - self.dataset_id = dataset_id - self.location = location - self.project_id = project_id - self.metadata = metadata - self.timeout = timeout - self.retry = retry - self.gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - @staticmethod - def _parse_dataset_id(dataset_id: str | list[str]) -> list[str]: - if not isinstance(dataset_id, str): - return dataset_id - try: - return ast.literal_eval(dataset_id) - except (SyntaxError, ValueError): - return dataset_id.split(",") - - def execute(self, context: Context): - hook = CloudAutoMLHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - hook.get_dataset( - dataset_id=self.dataset_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - dataset_id_list = self._parse_dataset_id(self.dataset_id) - for dataset_id in dataset_id_list: - self.log.info("Deleting dataset %s", dataset_id) - hook.delete_dataset( - dataset_id=dataset_id, - location=self.location, - project_id=self.project_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) - self.log.info("Dataset deleted.") diff --git a/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py b/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py deleted file mode 100644 index 2fa2c75fed899..0000000000000 --- a/providers/google/src/airflow/providers/google/cloud/operators/mlengine.py +++ /dev/null @@ -1,111 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
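For reference, the ``dataset_id`` handling that disappears with the automl module above is easy to reproduce in isolation: the removed ``AutoMLDeleteDatasetOperator`` accepted a list of IDs, a string containing a Python literal list, or a plain comma-separated string. A minimal standalone sketch of that parsing logic (the function name is ours, not a provider API):

.. code-block:: python

    # Standalone sketch of the dataset_id parsing implemented by the removed
    # AutoMLDeleteDatasetOperator._parse_dataset_id above.
    from __future__ import annotations

    import ast


    def parse_dataset_id(dataset_id: str | list[str]) -> list[str]:
        if not isinstance(dataset_id, str):
            return dataset_id  # already a list of IDs
        try:
            # a string such as "['a', 'b']" is evaluated as a Python literal
            return ast.literal_eval(dataset_id)
        except (SyntaxError, ValueError):
            # anything else is treated as a comma-separated string of IDs
            return dataset_id.split(",")


    assert parse_dataset_id(["a", "b"]) == ["a", "b"]
    assert parse_dataset_id("['a', 'b']") == ["a", "b"]
    assert parse_dataset_id("a,b") == ["a", "b"]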
-"""This module contains Google Cloud MLEngine operators.""" - -from __future__ import annotations - -import logging -from collections.abc import Sequence -from typing import TYPE_CHECKING - -from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook -from airflow.providers.google.cloud.links.mlengine import MLEngineModelLink -from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.common.deprecated import deprecated -from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID - -if TYPE_CHECKING: - from airflow.providers.common.compat.sdk import Context - - -log = logging.getLogger(__name__) - - -@deprecated( - planned_removal_date="November 01, 2025", - use_instead="appropriate VertexAI operator", - reason="All the functionality of legacy MLEngine and new features are available on the Vertex AI.", - category=AirflowProviderDeprecationWarning, -) -class MLEngineCreateModelOperator(GoogleCloudBaseOperator): - """ - Creates a new model. - - .. warning:: - This operator is deprecated. Please use appropriate VertexAI operator from - :class:`airflow.providers.google.cloud.operators.vertex_ai` instead. - - .. seealso:: - For more information on how to use this operator, take a look at the guide: - :ref:`howto/operator:MLEngineCreateModelOperator` - - The model should be provided by the `model` parameter. - - :param model: A dictionary containing the information about the model. - :param project_id: The Google Cloud project name to which MLEngine model belongs. - If set to None or missing, the default project_id from the Google Cloud connection is used. - (templated) - :param gcp_conn_id: The connection ID to use when fetching connection info. - :param impersonation_chain: Optional service account to impersonate using short-term - credentials, or chained list of accounts required to get the access_token - of the last account in the list, which will be impersonated in the request. - If set as a string, the account must grant the originating account - the Service Account Token Creator IAM role. - If set as a sequence, the identities from the list must grant - Service Account Token Creator IAM role to the directly preceding identity, with first - account from the list granting this role to the originating account (templated). 
- """ - - template_fields: Sequence[str] = ( - "project_id", - "model", - "impersonation_chain", - ) - operator_extra_links = (MLEngineModelLink(),) - - def __init__( - self, - *, - model: dict, - project_id: str = PROVIDE_PROJECT_ID, - gcp_conn_id: str = "google_cloud_default", - impersonation_chain: str | Sequence[str] | None = None, - **kwargs, - ) -> None: - super().__init__(**kwargs) - self.project_id = project_id - self.model = model - self._gcp_conn_id = gcp_conn_id - self.impersonation_chain = impersonation_chain - - def execute(self, context: Context): - hook = MLEngineHook( - gcp_conn_id=self._gcp_conn_id, - impersonation_chain=self.impersonation_chain, - ) - - project_id = self.project_id or hook.project_id - if project_id: - MLEngineModelLink.persist( - context=context, - project_id=project_id, - model_id=self.model["name"], - ) - - return hook.create_model(project_id=self.project_id, model=self.model) diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index 6f585a2278fc8..12d473d44a2e0 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -43,13 +43,6 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-google/operators/ads.rst"], "tags": ["gmp"], }, - { - "integration-name": "Google AutoML", - "external-doc-url": "https://cloud.google.com/automl/", - "how-to-guide": ["/docs/apache-airflow-providers-google/operators/cloud/automl.rst"], - "logo": "/docs/integration-logos/Cloud-AutoML.png", - "tags": ["gcp"], - }, { "integration-name": "Google BigQuery Data Transfer Service", "external-doc-url": "https://cloud.google.com/bigquery/transfer/", @@ -495,10 +488,6 @@ def get_provider_info(): "integration-name": "Google Cloud AlloyDB", "python-modules": ["airflow.providers.google.cloud.operators.alloy_db"], }, - { - "integration-name": "Google AutoML", - "python-modules": ["airflow.providers.google.cloud.operators.automl"], - }, { "integration-name": "Google BigQuery", "python-modules": ["airflow.providers.google.cloud.operators.bigquery"], @@ -583,10 +572,6 @@ def get_provider_info(): "integration-name": "Google Kubernetes Engine", "python-modules": ["airflow.providers.google.cloud.operators.kubernetes_engine"], }, - { - "integration-name": "Google Machine Learning Engine", - "python-modules": ["airflow.providers.google.cloud.operators.mlengine"], - }, { "integration-name": "Google Cloud Natural Language", "python-modules": ["airflow.providers.google.cloud.operators.natural_language"], @@ -826,10 +811,6 @@ def get_provider_info(): ], "hooks": [ {"integration-name": "Google Ads", "python-modules": ["airflow.providers.google.ads.hooks.ads"]}, - { - "integration-name": "Google AutoML", - "python-modules": ["airflow.providers.google.cloud.hooks.automl"], - }, { "integration-name": "Google BigQuery", "python-modules": ["airflow.providers.google.cloud.hooks.bigquery"], diff --git a/providers/google/tests/deprecations_ignore.yml b/providers/google/tests/deprecations_ignore.yml index 66a1508a6ec52..d07c6bdfc21a0 100644 --- a/providers/google/tests/deprecations_ignore.yml +++ b/providers/google/tests/deprecations_ignore.yml @@ -80,26 +80,6 @@ - providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_wc_with_last_modified_time_with_all_true_cond - 
providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_wc_with_last_modified_time_with_one_true_cond - providers/google/tests/unit/google/cloud/transfers/test_gcs_to_gcs.py::TestGoogleCloudStorageToCloudStorageOperator::test_wc_with_no_last_modified_time -- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationLegacyDatasetLink::test_get_link -- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationDatasetListLink::test_get_link -- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationLegacyModelLink::test_get_link -- providers/google/tests/unit/google/cloud/links/test_translate.py::TestTranslationLegacyModelTrainLink::test_get_link -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_get_conn -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_prediction_client -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_create_model -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_batch_predict -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_predict -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_create_dataset -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_import_dataset -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_list_column_specs -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_get_model -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_delete_model -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_update_dataset -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_deploy_model -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_list_table_specs -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_list_datasets -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_delete_dataset -- providers/google/tests/unit/google/cloud/hooks/test_automl.py::TestAutoMLHook::test_get_dataset - providers/google/tests/unit/google/cloud/hooks/test_datacatalog.py::TestCloudDataCatalog::test_lookup_entry_with_linked_resource - providers/google/tests/unit/google/cloud/hooks/test_datacatalog.py::TestCloudDataCatalog::test_lookup_entry_with_sql_resource - providers/google/tests/unit/google/cloud/hooks/test_datacatalog.py::TestCloudDataCatalog::test_lookup_entry_without_resource diff --git a/providers/google/tests/unit/google/cloud/hooks/test_automl.py b/providers/google/tests/unit/google/cloud/hooks/test_automl.py deleted file mode 100644 index ec49f23df9e2c..0000000000000 --- a/providers/google/tests/unit/google/cloud/hooks/test_automl.py +++ /dev/null @@ -1,257 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from __future__ import annotations - -from unittest import mock - -from google.api_core.gapic_v1.method import DEFAULT -from google.cloud.automl_v1beta1 import AutoMlClient - -from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook -from airflow.providers.google.common.consts import CLIENT_INFO - -from unit.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_no_default_project_id - -CREDENTIALS = "test-creds" -TASK_ID = "test-automl-hook" -GCP_PROJECT_ID = "test-project" -GCP_LOCATION = "test-location" -MODEL_NAME = "test_model" -MODEL_ID = "projects/198907790164/locations/us-central1/models/TBL9195602771183665152" -DATASET_ID = "TBL123456789" -MODEL = { - "display_name": MODEL_NAME, - "dataset_id": DATASET_ID, - "tables_model_metadata": {"train_budget_milli_node_hours": 1000}, -} - -LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}" -MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}" -DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}" - -INPUT_CONFIG = {"input": "value"} -OUTPUT_CONFIG = {"output": "value"} -PAYLOAD = {"test": "payload"} -DATASET = {"dataset_id": "data"} -MASK = {"field": "mask"} - - -class TestAutoMLHook: - def setup_method(self): - with mock.patch( - "airflow.providers.google.cloud.hooks.automl.GoogleBaseHook.__init__", - new=mock_base_gcp_hook_no_default_project_id, - ): - self.hook = CloudAutoMLHook() - self.hook.get_credentials = mock.MagicMock(return_value=CREDENTIALS) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient") - def test_get_conn(self, mock_automl_client): - self.hook.get_conn() - mock_automl_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient") - def test_prediction_client(self, mock_prediction_client): - client = self.hook.prediction_client # noqa: F841 - mock_prediction_client.assert_called_once_with(credentials=CREDENTIALS, client_info=CLIENT_INFO) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_model") - def test_create_model(self, mock_create_model): - self.hook.create_model(model=MODEL, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_create_model.assert_called_once_with( - request=dict(parent=LOCATION_PATH, model=MODEL), retry=DEFAULT, timeout=None, metadata=() - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.batch_predict") - def test_batch_predict(self, mock_batch_predict): - self.hook.batch_predict( - model_id=MODEL_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - input_config=INPUT_CONFIG, - output_config=OUTPUT_CONFIG, - ) - - mock_batch_predict.assert_called_once_with( - request=dict( - name=MODEL_PATH, input_config=INPUT_CONFIG, output_config=OUTPUT_CONFIG, params=None - ), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.PredictionServiceClient.predict") - def test_predict(self, mock_predict): - self.hook.predict( - 
model_id=MODEL_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - payload=PAYLOAD, - ) - - mock_predict.assert_called_once_with( - request=dict(name=MODEL_PATH, payload=PAYLOAD, params=None), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.create_dataset") - def test_create_dataset(self, mock_create_dataset): - self.hook.create_dataset(dataset=DATASET, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_create_dataset.assert_called_once_with( - request=dict(parent=LOCATION_PATH, dataset=DATASET), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.import_data") - def test_import_dataset(self, mock_import_data): - self.hook.import_data( - dataset_id=DATASET_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - input_config=INPUT_CONFIG, - ) - - mock_import_data.assert_called_once_with( - request=dict(name=DATASET_PATH, input_config=INPUT_CONFIG), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_column_specs") - def test_list_column_specs(self, mock_list_column_specs): - table_spec = "table_spec_id" - filter_ = "filter" - page_size = 42 - - self.hook.list_column_specs( - dataset_id=DATASET_ID, - table_spec_id=table_spec, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - field_mask=MASK, - filter_=filter_, - page_size=page_size, - ) - - parent = AutoMlClient.table_spec_path(GCP_PROJECT_ID, GCP_LOCATION, DATASET_ID, table_spec) - mock_list_column_specs.assert_called_once_with( - request=dict(parent=parent, field_mask=MASK, filter=filter_, page_size=page_size), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_model") - def test_get_model(self, mock_get_model): - self.hook.get_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_get_model.assert_called_once_with( - request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=() - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_model") - def test_delete_model(self, mock_delete_model): - self.hook.delete_model(model_id=MODEL_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_delete_model.assert_called_once_with( - request=dict(name=MODEL_PATH), retry=DEFAULT, timeout=None, metadata=() - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.update_dataset") - def test_update_dataset(self, mock_update_dataset): - self.hook.update_dataset( - dataset=DATASET, - update_mask=MASK, - ) - - mock_update_dataset.assert_called_once_with( - request=dict(dataset=DATASET, update_mask=MASK), retry=DEFAULT, timeout=None, metadata=() - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.deploy_model") - def test_deploy_model(self, mock_deploy_model): - image_detection_metadata = {} - - self.hook.deploy_model( - model_id=MODEL_ID, - image_detection_metadata=image_detection_metadata, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - ) - - mock_deploy_model.assert_called_once_with( - request=dict( - name=MODEL_PATH, - image_object_detection_model_deployment_metadata=image_detection_metadata, - ), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_table_specs") - def test_list_table_specs(self, mock_list_table_specs): 
- filter_ = "filter" - page_size = 42 - - self.hook.list_table_specs( - dataset_id=DATASET_ID, - location=GCP_LOCATION, - project_id=GCP_PROJECT_ID, - filter_=filter_, - page_size=page_size, - ) - - mock_list_table_specs.assert_called_once_with( - request=dict(parent=DATASET_PATH, filter=filter_, page_size=page_size), - retry=DEFAULT, - timeout=None, - metadata=(), - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.list_datasets") - def test_list_datasets(self, mock_list_datasets): - self.hook.list_datasets(location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_list_datasets.assert_called_once_with( - request=dict(parent=LOCATION_PATH), retry=DEFAULT, timeout=None, metadata=() - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.delete_dataset") - def test_delete_dataset(self, mock_delete_dataset): - self.hook.delete_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_delete_dataset.assert_called_once_with( - request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, metadata=() - ) - - @mock.patch("airflow.providers.google.cloud.hooks.automl.AutoMlClient.get_dataset") - def test_get_dataset(self, mock_get_dataset): - self.hook.get_dataset(dataset_id=DATASET_ID, location=GCP_LOCATION, project_id=GCP_PROJECT_ID) - - mock_get_dataset.assert_called_once_with( - request=dict(name=DATASET_PATH), retry=DEFAULT, timeout=None, metadata=() - ) diff --git a/providers/google/tests/unit/google/cloud/links/test_translate.py b/providers/google/tests/unit/google/cloud/links/test_translate.py deleted file mode 100644 index 0adc28e53e218..0000000000000 --- a/providers/google/tests/unit/google/cloud/links/test_translate.py +++ /dev/null @@ -1,198 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
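The deleted hook tests above all share one pattern: patch a single client method, invoke the hook wrapper, and assert that the wrapper forwarded its arguments as one ``request`` dict plus the standard retry/timeout/metadata defaults. A self-contained rendering of that pattern using a mocked stand-in client, so the snippet does not depend on the now-removed automl module:

.. code-block:: python

    from unittest import mock


    def get_dataset(client, dataset_path):
        # Mirrors how the removed CloudAutoMLHook built a single request dict
        # per client call (compare the assertions in the deleted tests above).
        return client.get_dataset(
            request=dict(name=dataset_path), retry=None, timeout=None, metadata=()
        )


    client = mock.MagicMock()
    get_dataset(client, "projects/p/locations/l/datasets/d")
    client.get_dataset.assert_called_once_with(
        request=dict(name="projects/p/locations/l/datasets/d"),
        retry=None,
        timeout=None,
        metadata=(),
    )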
-from __future__ import annotations
-
-from unittest import mock
-
-import pytest
-
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
-
-# For no Pydantic environment, we need to skip the tests
-pytest.importorskip("google.cloud.aiplatform_v1")
-
-from airflow.providers.google.cloud.links.translate import (
-    TRANSLATION_BASE_LINK,
-    TranslationDatasetListLink,
-    TranslationLegacyDatasetLink,
-    TranslationLegacyModelLink,
-    TranslationLegacyModelTrainLink,
-)
-from airflow.providers.google.cloud.operators.automl import (
-    AutoMLCreateDatasetOperator,
-    AutoMLListDatasetOperator,
-    AutoMLTrainModelOperator,
-)
-
-if AIRFLOW_V_3_0_PLUS:
-    from airflow.sdk.execution_time.comms import XComResult
-
-GCP_LOCATION = "test-location"
-GCP_PROJECT_ID = "test-project"
-DATASET = "test-dataset"
-MODEL = "test-model"
-
-
-class TestTranslationLegacyDatasetLink:
-    @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
-        expected_url = f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/sentences?project={GCP_PROJECT_ID}"
-        link = TranslationLegacyDatasetLink()
-        ti = create_task_instance_of_operator(
-            AutoMLCreateDatasetOperator,
-            dag_id="test_legacy_dataset_link_dag",
-            task_id="test_legacy_dataset_link_task",
-            dataset=DATASET,
-            location=GCP_LOCATION,
-        )
-        session.add(ti)
-        session.commit()
-
-        link.persist(
-            context={"ti": ti, "task": ti.task},
-            dataset_id=DATASET,
-            project_id=GCP_PROJECT_ID,
-            location=GCP_LOCATION,
-        )
-
-        if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
-            mock_supervisor_comms.send.return_value = XComResult(
-                key="key",
-                value={
-                    "location": ti.task.location,
-                    "dataset_id": DATASET,
-                    "project_id": GCP_PROJECT_ID,
-                },
-            )
-        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
-        assert actual_url == expected_url
-
-
-class TestTranslationDatasetListLink:
-    @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
-        expected_url = f"{TRANSLATION_BASE_LINK}/datasets?project={GCP_PROJECT_ID}"
-        link = TranslationDatasetListLink()
-        ti = create_task_instance_of_operator(
-            AutoMLListDatasetOperator,
-            dag_id="test_dataset_list_link_dag",
-            task_id="test_dataset_list_link_task",
-            location=GCP_LOCATION,
-        )
-        session.add(ti)
-        session.commit()
-        link.persist(context={"ti": ti, "task": ti.task}, project_id=GCP_PROJECT_ID)
-
-        if AIRFLOW_V_3_0_PLUS and mock_supervisor_comms:
-            mock_supervisor_comms.send.return_value = XComResult(
-                key="key",
-                value={
-                    "project_id": GCP_PROJECT_ID,
-                },
-            )
-        actual_url = link.get_link(operator=ti.task, ti_key=ti.key)
-        assert actual_url == expected_url
-
-
-class TestTranslationLegacyModelLink:
-    @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
-        expected_url = (
-            f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
-            f"evaluate;modelId={MODEL}?project={GCP_PROJECT_ID}"
-        )
-        link = TranslationLegacyModelLink()
-        ti = create_task_instance_of_operator(
-            AutoMLTrainModelOperator,
-            dag_id="test_legacy_model_link_dag",
-            task_id="test_legacy_model_link_task",
-            model=MODEL,
-            project_id=GCP_PROJECT_ID,
-            location=GCP_LOCATION,
-        )
-        session.add(ti)
-        session.commit()
-        task = mock.MagicMock()
-        task.extra_links_params = {
-            "dataset_id": DATASET,
-            "model_id": MODEL,
-            "project_id": GCP_PROJECT_ID,
-            "location": GCP_LOCATION,
-        }
-        link.persist(
-            context={"ti": ti, "task": task},
-            dataset_id=DATASET,
-            model_id=MODEL,
-            project_id=GCP_PROJECT_ID,
-        )
-        if mock_supervisor_comms:
-            mock_supervisor_comms.send.return_value = XComResult(
-                key="key",
-                value={
-                    "location": ti.task.location,
-                    "dataset_id": DATASET,
-                    "model_id": MODEL,
-                    "project_id": GCP_PROJECT_ID,
-                },
-            )
-        actual_url = link.get_link(operator=task, ti_key=ti.key)
-        assert actual_url == expected_url
-
-
-class TestTranslationLegacyModelTrainLink:
-    @pytest.mark.db_test
-    def test_get_link(self, create_task_instance_of_operator, session, mock_supervisor_comms):
-        expected_url = (
-            f"{TRANSLATION_BASE_LINK}/locations/{GCP_LOCATION}/datasets/{DATASET}/"
-            f"train?project={GCP_PROJECT_ID}"
-        )
-        link = TranslationLegacyModelTrainLink()
-        ti = create_task_instance_of_operator(
-            AutoMLTrainModelOperator,
-            dag_id="test_legacy_model_train_link_dag",
-            task_id="test_legacy_model_train_link_task",
-            model={"dataset_id": DATASET},
-            project_id=GCP_PROJECT_ID,
-            location=GCP_LOCATION,
-        )
-        session.add(ti)
-        session.commit()
-        task = mock.MagicMock()
-        task.extra_links_params = {
-            "dataset_id": DATASET,
-            "project_id": GCP_PROJECT_ID,
-            "location": GCP_LOCATION,
-        }
-        link.persist(
-            context={"ti": ti, "task": task},
-            dataset_id=DATASET,
-            project_id=GCP_PROJECT_ID,
-            location=GCP_LOCATION,
-        )
-
-        if mock_supervisor_comms:
-            mock_supervisor_comms.send.return_value = XComResult(
-                key="key",
-                value={
-                    "location": ti.task.location,
-                    "dataset_id": ti.task.model["dataset_id"],
-                    "project_id": GCP_PROJECT_ID,
-                },
-            )
-        actual_url = link.get_link(operator=task, ti_key=ti.key)
-        assert actual_url == expected_url
diff --git a/providers/google/tests/unit/google/cloud/operators/test_automl.py b/providers/google/tests/unit/google/cloud/operators/test_automl.py
deleted file mode 100644
index 18f4de4af709d..0000000000000
--- a/providers/google/tests/unit/google/cloud/operators/test_automl.py
+++ /dev/null
@@ -1,616 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-import copy
-from unittest import mock
-
-import pytest
-
-# For no Pydantic environment, we need to skip the tests
-pytest.importorskip("google.cloud.aiplatform_v1")
-
-from google.api_core.gapic_v1.method import DEFAULT
-from google.cloud.automl_v1beta1 import Dataset, Model, PredictResponse
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.hooks.automl import CloudAutoMLHook
-from airflow.providers.google.cloud.hooks.vertex_ai.prediction_service import PredictionServiceHook
-from airflow.providers.google.cloud.operators.automl import (
-    AutoMLCreateDatasetOperator,
-    AutoMLDeleteDatasetOperator,
-    AutoMLDeleteModelOperator,
-    AutoMLDeployModelOperator,
-    AutoMLGetModelOperator,
-    AutoMLImportDataOperator,
-    AutoMLListDatasetOperator,
-    AutoMLPredictOperator,
-    AutoMLTablesListColumnSpecsOperator,
-    AutoMLTablesListTableSpecsOperator,
-    AutoMLTablesUpdateDatasetOperator,
-    AutoMLTrainModelOperator,
-)
-
-CREDENTIALS = "test-creds"
-TASK_ID = "test-automl-hook"
-GCP_PROJECT_ID = "test-project"
-GCP_LOCATION = "test-location"
-MODEL_NAME = "test_model"
-MODEL_ID = "TBL9195602771183665152"
-DATASET_ID = "TBL123456789"
-MODEL = {
-    "display_name": MODEL_NAME,
-    "dataset_id": DATASET_ID,
-    "translation_model_metadata": {"train_budget_milli_node_hours": 1000},
-}
-MODEL_DEPRECATED = {
-    "display_name": MODEL_NAME,
-    "dataset_id": DATASET_ID,
-    "tables_model_metadata": {"train_budget_milli_node_hours": 1000},
-}
-
-LOCATION_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}"
-MODEL_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/models/{MODEL_ID}"
-DATASET_PATH = f"projects/{GCP_PROJECT_ID}/locations/{GCP_LOCATION}/datasets/{DATASET_ID}"
-
-INPUT_CONFIG = {"input": "value"}
-OUTPUT_CONFIG = {"output": "value"}
-PAYLOAD = {"test": "payload"}
-DATASET = {"dataset_id": "data", "translation_dataset_metadata": "data"}
-DATASET_DEPRECATED = {"tables_model_metadata": "data"}
-MASK = {"field": "mask"}
-
-extract_object_id = CloudAutoMLHook.extract_object_id
-
-
-class TestAutoMLTrainModelOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        mock_hook.return_value.create_model.return_value.result.return_value = Model(name=MODEL_PATH)
-        mock_hook.return_value.extract_object_id = extract_object_id
-        mock_hook.return_value.wait_for_operation.return_value = Model()
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLTrainModelOperator(
-                model=MODEL,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                task_id=TASK_ID,
-            )
-        op.execute(context=mock.MagicMock())
-
-        mock_hook.return_value.create_model.assert_called_once_with(
-            model=MODEL,
-            location=GCP_LOCATION,
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-            metadata=(),
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLTrainModelOperator,
-                # Templated fields
-                model="{{ 'model' }}",
-                location="{{ 'location' }}",
-                impersonation_chain="{{ 'impersonation_chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLTrainModelOperator = ti.task
-        assert task.model == "model"
-        assert task.location == "location"
-        assert task.impersonation_chain == "impersonation_chain"
-
-
-class TestAutoMLPredictOperator:
-    @mock.patch("airflow.providers.google.cloud.links.translate.TranslationLegacyModelPredictLink.persist")
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook, mock_link_persist):
-        mock_hook.return_value.predict.return_value = PredictResponse()
-        mock_hook.return_value.get_model.return_value = mock.MagicMock(**MODEL)
-        mock_context = {"ti": mock.MagicMock()}
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLPredictOperator(
-                model_id=MODEL_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                payload=PAYLOAD,
-                task_id=TASK_ID,
-                operation_params={"TEST_KEY": "TEST_VALUE"},
-            )
-        op.execute(context=mock_context)
-        mock_hook.return_value.predict.assert_called_once_with(
-            location=GCP_LOCATION,
-            metadata=(),
-            model_id=MODEL_ID,
-            params={"TEST_KEY": "TEST_VALUE"},
-            payload=PAYLOAD,
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-        mock_link_persist.assert_called_once_with(
-            context=mock_context,
-            model_id=MODEL_ID,
-            project_id=GCP_PROJECT_ID,
-            location=GCP_LOCATION,
-            dataset_id=DATASET_ID,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLPredictOperator,
-                # Templated fields
-                model_id="{{ 'model-id' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-                payload={},
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLPredictOperator = ti.task
-        assert task.model_id == "model-id"
-        assert task.project_id == "project-id"
-        assert task.location == "location"
-        assert task.impersonation_chain == "impersonation-chain"
-
-    @pytest.mark.db_test
-    def test_hook_type(self):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLPredictOperator(
-                model_id=MODEL_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                payload=PAYLOAD,
-                task_id=TASK_ID,
-                operation_params={"TEST_KEY": "TEST_VALUE"},
-            )
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            assert isinstance(op.hook, CloudAutoMLHook)
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLPredictOperator(
-                endpoint_id="endpoint_id",
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                payload=PAYLOAD,
-                task_id=TASK_ID,
-                operation_params={"TEST_KEY": "TEST_VALUE"},
-            )
-        assert isinstance(op.hook, PredictionServiceHook)
-
-
-class TestAutoMLCreateImportOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        mock_hook.return_value.create_dataset.return_value = Dataset(name=DATASET_PATH)
-        mock_hook.return_value.extract_object_id = extract_object_id
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLCreateDatasetOperator(
-                dataset=DATASET,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                task_id=TASK_ID,
-            )
-        op.execute(context=mock.MagicMock())
-
-        mock_hook.return_value.create_dataset.assert_called_once_with(
-            dataset=DATASET,
-            location=GCP_LOCATION,
-            metadata=(),
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLCreateDatasetOperator,
-                # Templated fields
-                dataset="{{ 'dataset' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLCreateDatasetOperator = ti.task
-        assert task.dataset == "dataset"
-        assert task.project_id == "project-id"
-        assert task.location == "location"
-        assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLTablesListColumnsSpecsOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        table_spec = "table_spec_id"
-        filter_ = "filter"
-        page_size = 42
-
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            AutoMLTablesListColumnSpecsOperator(
-                dataset_id=DATASET_ID,
-                table_spec_id=table_spec,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                field_mask=MASK,
-                filter_=filter_,
-                page_size=page_size,
-                task_id=TASK_ID,
-            )
-        mock_hook.assert_not_called()
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            create_task_instance_of_operator(
-                AutoMLTablesListColumnSpecsOperator,
-                # Templated fields
-                dataset_id="{{ 'dataset-id' }}",
-                table_spec_id="{{ 'table-spec-id' }}",
-                field_mask="{{ 'field-mask' }}",
-                filter_="{{ 'filter-' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-
-
-class TestAutoMLTablesUpdateDatasetOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        mock_hook.return_value.update_dataset.return_value = Dataset(name=DATASET_PATH)
-        dataset = copy.deepcopy(DATASET)
-        dataset["name"] = DATASET_ID
-
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            AutoMLTablesUpdateDatasetOperator(
-                dataset=dataset,
-                update_mask=MASK,
-                location=GCP_LOCATION,
-                task_id=TASK_ID,
-            )
-        mock_hook.assert_not_called()
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            create_task_instance_of_operator(
-                AutoMLTablesUpdateDatasetOperator,
-                # Templated fields
-                dataset="{{ 'dataset' }}",
-                update_mask="{{ 'update-mask' }}",
-                location="{{ 'location' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-
-
-class TestAutoMLGetModelOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        mock_hook.return_value.get_model.return_value = Model(name=MODEL_PATH)
-        mock_hook.return_value.extract_object_id = extract_object_id
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLGetModelOperator(
-                model_id=MODEL_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                task_id=TASK_ID,
-            )
-
-        op.execute(context=mock.MagicMock())
-
-        mock_hook.return_value.get_model.assert_called_once_with(
-            location=GCP_LOCATION,
-            metadata=(),
-            model_id=MODEL_ID,
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLGetModelOperator,
-                # Templated fields
-                model_id="{{ 'model-id' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLGetModelOperator = ti.task
-        assert task.model_id == "model-id"
-        assert task.location == "location"
-        assert task.project_id == "project-id"
-        assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLDeleteModelOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLDeleteModelOperator(
-                model_id=MODEL_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                task_id=TASK_ID,
-            )
-        op.execute(context=None)
-        mock_hook.return_value.delete_model.assert_called_once_with(
-            location=GCP_LOCATION,
-            metadata=(),
-            model_id=MODEL_ID,
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLDeleteModelOperator,
-                # Templated fields
-                model_id="{{ 'model-id' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLDeleteModelOperator = ti.task
-        assert task.model_id == "model-id"
-        assert task.location == "location"
-        assert task.project_id == "project-id"
-        assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLDeployModelOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        image_detection_metadata = {}
-
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            AutoMLDeployModelOperator(
-                model_id=MODEL_ID,
-                image_detection_metadata=image_detection_metadata,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                task_id=TASK_ID,
-            )
-
-        mock_hook.assert_not_called()
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            create_task_instance_of_operator(
-                AutoMLDeployModelOperator,
-                # Templated fields
-                model_id="{{ 'model-id' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-
-
-class TestAutoMLDatasetImportOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLImportDataOperator(
-                dataset_id=DATASET_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                input_config=INPUT_CONFIG,
-                task_id=TASK_ID,
-            )
-
-        op.execute(context=mock.MagicMock())
-
-        mock_hook.return_value.import_data.assert_called_once_with(
-            input_config=INPUT_CONFIG,
-            location=GCP_LOCATION,
-            metadata=(),
-            dataset_id=DATASET_ID,
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLImportDataOperator,
-                # Templated fields
-                dataset_id="{{ 'dataset-id' }}",
-                input_config="{{ 'input-config' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLImportDataOperator = ti.task
-        assert task.dataset_id == "dataset-id"
-        assert task.input_config == "input-config"
-        assert task.location == "location"
-        assert task.project_id == "project-id"
-        assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLTablesListTableSpecsOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        filter_ = "filter"
-        page_size = 42
-
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            AutoMLTablesListTableSpecsOperator(
-                dataset_id=DATASET_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                filter_=filter_,
-                page_size=page_size,
-                task_id=TASK_ID,
-            )
-        mock_hook.assert_not_called()
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            create_task_instance_of_operator(
-                AutoMLTablesListTableSpecsOperator,
-                # Templated fields
-                dataset_id="{{ 'dataset-id' }}",
-                filter_="{{ 'filter-' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-
-
-class TestAutoMLDatasetListOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLListDatasetOperator(location=GCP_LOCATION, project_id=GCP_PROJECT_ID, task_id=TASK_ID)
-
-        op.execute(context=mock.MagicMock())
-
-        mock_hook.return_value.list_datasets.assert_called_once_with(
-            location=GCP_LOCATION,
-            metadata=(),
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLListDatasetOperator,
-                # Templated fields
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLListDatasetOperator = ti.task
-        assert task.location == "location"
-        assert task.project_id == "project-id"
-        assert task.impersonation_chain == "impersonation-chain"
-
-
-class TestAutoMLDatasetDeleteOperator:
-    @mock.patch("airflow.providers.google.cloud.operators.automl.CloudAutoMLHook")
-    def test_execute(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            op = AutoMLDeleteDatasetOperator(
-                dataset_id=DATASET_ID,
-                location=GCP_LOCATION,
-                project_id=GCP_PROJECT_ID,
-                task_id=TASK_ID,
-            )
-        op.execute(context=None)
-        mock_hook.return_value.delete_dataset.assert_called_once_with(
-            location=GCP_LOCATION,
-            dataset_id=DATASET_ID,
-            metadata=(),
-            project_id=GCP_PROJECT_ID,
-            retry=DEFAULT,
-            timeout=None,
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                AutoMLDeleteDatasetOperator,
-                # Templated fields
-                dataset_id="{{ 'dataset-id' }}",
-                location="{{ 'location' }}",
-                project_id="{{ 'project-id' }}",
-                impersonation_chain="{{ 'impersonation-chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: AutoMLDeleteDatasetOperator = ti.task
-        assert task.dataset_id == "dataset-id"
-        assert task.location == "location"
-        assert task.project_id == "project-id"
-        assert task.impersonation_chain == "impersonation-chain"
diff --git a/providers/google/tests/unit/google/cloud/operators/test_mlengine.py b/providers/google/tests/unit/google/cloud/operators/test_mlengine.py
deleted file mode 100644
index 5ca6cf8cea73d..0000000000000
--- a/providers/google/tests/unit/google/cloud/operators/test_mlengine.py
+++ /dev/null
@@ -1,77 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from __future__ import annotations
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.google.cloud.operators.mlengine import MLEngineCreateModelOperator
-
-TEST_PROJECT_ID = "test-project-id"
-TEST_MODEL_NAME = "test-model-name"
-TEST_GCP_CONN_ID = "test-gcp-conn-id"
-TEST_IMPERSONATION_CHAIN = ["ACCOUNT_1", "ACCOUNT_2", "ACCOUNT_3"]
-TEST_MODEL = {
-    "name": TEST_MODEL_NAME,
-}
-MLENGINE_AI_PATH = "airflow.providers.google.cloud.operators.mlengine.{}"
-
-
-class TestMLEngineCreateModelOperator:
-    @patch(MLENGINE_AI_PATH.format("MLEngineHook"))
-    def test_success_create_model(self, mock_hook):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            task = MLEngineCreateModelOperator(
-                task_id="task-id",
-                project_id=TEST_PROJECT_ID,
-                model=TEST_MODEL,
-                gcp_conn_id=TEST_GCP_CONN_ID,
-                impersonation_chain=TEST_IMPERSONATION_CHAIN,
-            )
-
-        task.execute(context=MagicMock())
-
-        mock_hook.assert_called_once_with(
-            gcp_conn_id=TEST_GCP_CONN_ID,
-            impersonation_chain=TEST_IMPERSONATION_CHAIN,
-        )
-        mock_hook.return_value.create_model.assert_called_once_with(
-            project_id=TEST_PROJECT_ID, model=TEST_MODEL
-        )
-
-    @pytest.mark.db_test
-    def test_templating(self, create_task_instance_of_operator, session):
-        with pytest.warns(AirflowProviderDeprecationWarning):
-            ti = create_task_instance_of_operator(
-                MLEngineCreateModelOperator,
-                # Templated fields
-                project_id="{{ 'project_id' }}",
-                model="{{ 'model' }}",
-                impersonation_chain="{{ 'impersonation_chain' }}",
-                # Other parameters
-                dag_id="test_template_body_templating_dag",
-                task_id="test_template_body_templating_task",
-            )
-        session.add(ti)
-        session.commit()
-        ti.render_templates()
-        task: MLEngineCreateModelOperator = ti.task
-        assert task.project_id == "project_id"
-        assert task.model == "model"
-        assert task.impersonation_chain == "impersonation_chain"