diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py index e81476b02168f..9740fb37939c5 100644 --- a/airflow-core/tests/unit/always/test_project_structure.py +++ b/airflow-core/tests/unit/always/test_project_structure.py @@ -61,6 +61,7 @@ def test_providers_modules_should_have_tests(self): # We should make sure that one goes to 0 # TODO(potiuk) - check if that test actually tests something OVERLOOKED_TESTS = [ + "providers/alibaba/tests/unit/alibaba/test_version_compat.py", "providers/amazon/tests/unit/amazon/aws/auth_manager/datamodels/test_login.py", "providers/amazon/tests/unit/amazon/aws/auth_manager/security_manager/test_aws_security_manager_override.py", "providers/amazon/tests/unit/amazon/aws/executors/batch/test_batch_executor_config.py", diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index a74aa9364431d..189696c5de826 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1090,6 +1090,8 @@ masterType materializations Matomo matomo +MaxCompute +maxcompute Maxime MaxRuntimeInSeconds mb @@ -1220,6 +1222,7 @@ objectstorage observability od odbc +odps ok Okta okta diff --git a/providers/alibaba/provider.yaml b/providers/alibaba/provider.yaml index 8456d38e353b2..007c1be5351f6 100644 --- a/providers/alibaba/provider.yaml +++ b/providers/alibaba/provider.yaml @@ -78,6 +78,9 @@ operators: - integration-name: Alibaba Cloud AnalyticDB Spark python-modules: - airflow.providers.alibaba.cloud.operators.analyticdb_spark + - integration-name: Alibaba Cloud MaxCompute + python-modules: + - airflow.providers.alibaba.cloud.operators.maxcompute sensors: - integration-name: Alibaba Cloud OSS @@ -94,6 +97,12 @@ hooks: - integration-name: Alibaba Cloud AnalyticDB Spark python-modules: - airflow.providers.alibaba.cloud.hooks.analyticdb_spark + - integration-name: Alibaba Cloud + python-modules: + - airflow.providers.alibaba.cloud.hooks.base_alibaba + - integration-name: Alibaba Cloud MaxCompute + python-modules: + - airflow.providers.alibaba.cloud.hooks.maxcompute connection-types: @@ -101,6 +110,13 @@ connection-types: connection-type: oss - hook-class-name: airflow.providers.alibaba.cloud.hooks.analyticdb_spark.AnalyticDBSparkHook connection-type: adb_spark + - hook-class-name: airflow.providers.alibaba.cloud.hooks.base_alibaba.AlibabaBaseHook + connection-type: alibaba_cloud + - hook-class-name: airflow.providers.alibaba.cloud.hooks.maxcompute.MaxComputeHook + connection-type: maxcompute logging: - airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler + +extra-links: + - airflow.providers.alibaba.cloud.links.maxcompute.MaxComputeLogViewLink diff --git a/providers/alibaba/pyproject.toml b/providers/alibaba/pyproject.toml index 351caa3be7c78..1fa810555e951 100644 --- a/providers/alibaba/pyproject.toml +++ b/providers/alibaba/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "oss2>=2.14.0", "alibabacloud_adb20211201>=1.0.0", "alibabacloud_tea_openapi>=0.3.7", + "pyodps>=0.12.2.2", ] [dependency-groups] diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/exceptions.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/exceptions.py new file mode 100644 index 0000000000000..0970c624af61b --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/exceptions.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + + +class MaxComputeConfigurationException(Exception): + """Raised when MaxCompute project or endpoint is not configured properly.""" diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py new file mode 100644 index 0000000000000..d583c7701176d --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/base_alibaba.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import Any, NamedTuple + +from airflow.hooks.base import BaseHook + + +class AccessKeyCredentials(NamedTuple): + """ + A NamedTuple to store Alibaba Cloud Access Key credentials. + + :param access_key_id: The Access Key ID for Alibaba Cloud authentication. + :param access_key_secret: The Access Key Secret for Alibaba Cloud authentication. + """ + + access_key_id: str + access_key_secret: str + + +class AlibabaBaseHook(BaseHook): + """ + A base hook for Alibaba Cloud-related hooks. + + This hook provides a common interface for authenticating using Alibaba Cloud credentials. + + Supports Access Key-based authentication. + + :param alibaba_cloud_conn_id: The connection ID to use when fetching connection info. + """ + + conn_name_attr = "alibabacloud_conn_id" + default_conn_name = "alibabacloud_default" + conn_type = "alibaba_cloud" + hook_name = "Alibaba Cloud" + + def __init__( + self, + alibabacloud_conn_id: str = "alibabacloud_default", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.alibaba_cloud_conn_id = alibabacloud_conn_id + self.extras: dict = self.get_connection(self.alibaba_cloud_conn_id).extra_dejson + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Return connection widgets to add to connection form.""" + from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget + from flask_babel import lazy_gettext + from wtforms import PasswordField + + return { + "access_key_id": PasswordField(lazy_gettext("Access Key ID"), widget=BS3PasswordFieldWidget()), + "access_key_secret": PasswordField( + lazy_gettext("Access Key Secret"), widget=BS3PasswordFieldWidget() + ), + } + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """Return custom field behaviour.""" + return super().get_ui_field_behaviour() + + def _get_field(self, field_name: str, default: Any = None) -> Any: + """Fetch a field from extras, and returns it.""" + value = self.extras.get(field_name) + return value if value is not None else default + + def get_access_key_credential(self) -> AccessKeyCredentials: + """ + Fetch Access Key Credential for authentication. + + :return: AccessKeyCredentials object containing access_key_id and access_key_secret. + """ + access_key_id = self._get_field("access_key_id", None) + access_key_secret = self._get_field("access_key_secret", None) + if not access_key_id: + raise ValueError("No access_key_id is specified.") + + if not access_key_secret: + raise ValueError("No access_key_secret is specified.") + + return AccessKeyCredentials(access_key_id, access_key_secret) diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/maxcompute.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/maxcompute.py new file mode 100644 index 0000000000000..2a249f123da13 --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/hooks/maxcompute.py @@ -0,0 +1,242 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Callable, TypeVar + +from odps import ODPS + +from airflow.providers.alibaba.cloud.exceptions import MaxComputeConfigurationException +from airflow.providers.alibaba.cloud.hooks.base_alibaba import AlibabaBaseHook + +if TYPE_CHECKING: + from odps.models import Instance + +RT = TypeVar("RT") + + +def fallback_to_default_project_endpoint(func: Callable[..., RT]) -> Callable[..., RT]: + """ + Provide fallback for MaxCompute project and endpoint to be used as a decorator. + + If the project or endpoint is None it will be replaced with the project from the + connection extra definition. + + :param func: function to wrap + :return: result of the function call + """ + + @functools.wraps(func) + def inner_wrapper(self, **kwargs) -> RT: + required_args = ("project", "endpoint") + for arg_name in required_args: + kwargs[arg_name] = kwargs.get(arg_name, getattr(self, arg_name)) + if not kwargs[arg_name]: + raise MaxComputeConfigurationException( + f'"{arg_name}" must be passed either as ' + "keyword parameter or as extra " + "in the MaxCompute connection definition. Both are not set!" + ) + + return func(self, **kwargs) + + return inner_wrapper + + +class MaxComputeHook(AlibabaBaseHook): + """ + Interact with Alibaba MaxCompute (previously known as ODPS). + + :param maxcompute_conn_id: The connection ID to use when fetching connection info. + """ + + conn_name_attr = "maxcompute_conn_id" + default_conn_name = "maxcompute_default" + conn_type = "maxcompute" + hook_name = "MaxCompute" + + def __init__(self, maxcompute_conn_id: str = "maxcompute_default", **kwargs) -> None: + self.maxcompute_conn_id = maxcompute_conn_id + super().__init__(alibabacloud_conn_id=maxcompute_conn_id, **kwargs) + + @classmethod + def get_connection_form_widgets(cls) -> dict[str, Any]: + """Return connection widgets to add to connection form.""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + connection_form_widgets = super().get_connection_form_widgets() + + connection_form_widgets["project"] = StringField( + lazy_gettext("Project"), + widget=BS3TextFieldWidget(), + ) + connection_form_widgets["endpoint"] = StringField( + lazy_gettext("Endpoint"), + widget=BS3TextFieldWidget(), + ) + + return connection_form_widgets + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + """Return custom field behaviour.""" + return { + "hidden_fields": ["host", "schema", "login", "password", "port", "extra"], + "relabeling": {}, + } + + @property + def project(self) -> str: + """ + Returns project ID. + + :return: ID of the project + """ + return self._get_field("project") + + @property + def endpoint(self) -> str: + """ + Returns MaxCompute Endpoint. + + :return: Endpoint of the MaxCompute project + """ + return self._get_field("endpoint") + + @fallback_to_default_project_endpoint + def get_client(self, *, project: str, endpoint: str) -> ODPS: + """ + Get an authenticated MaxCompute ODPS Client. + + :param project_id: Project ID for the project which the client acts on behalf of. + :param location: Default location for jobs / datasets / tables. + """ + creds = self.get_access_key_credential() + + return ODPS( + creds.access_key_id, + creds.access_key_secret, + project=project, + endpoint=endpoint, + ) + + @fallback_to_default_project_endpoint + def run_sql( + self, + *, + sql: str, + project: str | None = None, + endpoint: str | None = None, + priority: int | None = None, + running_cluster: str | None = None, + hints: dict[str, Any] | None = None, + aliases: dict[str, str] | None = None, + default_schema: str | None = None, + quota_name: str | None = None, + ) -> Instance: + """ + Run a given SQL statement in MaxCompute. + + The method will submit your SQL statement to MaxCompute + and return the corresponding task Instance object. + + .. seealso:: https://pyodps.readthedocs.io/en/latest/base-sql.html#execute-sql + + :param sql: The SQL statement to run. + :param project: The project ID to use. + :param endpoint: The endpoint to use. + :param priority: The priority of the SQL statement ranges from 0 to 9, + applicable to projects with the job priority feature enabled. + Takes precedence over the `odps.instance.priority` setting from `hints`. + Defaults to 9. + See https://www.alibabacloud.com/help/en/maxcompute/user-guide/job-priority + for details. + :param running_cluster: The cluster to run the SQL statement on. + :param hints: Hints for setting runtime parameters. See + https://pyodps.readthedocs.io/en/latest/base-sql.html#id4 and + https://www.alibabacloud.com/help/en/maxcompute/user-guide/flag-parameters + for details. + :param aliases: Aliases for the SQL statement. + :param default_schema: The default schema to use. + :param quota_name: The quota name to use. + Defaults to project default quota if not specified. + :return: The MaxCompute task instance. + """ + client = self.get_client(project=project, endpoint=endpoint) + + if priority is None and hints is not None: + priority = hints.get("odps.instance.priority") + + return client.run_sql( + sql=sql, + priority=priority, + running_cluster=running_cluster, + hints=hints, + aliases=aliases, + default_schema=default_schema, + quota_name=quota_name, + ) + + @fallback_to_default_project_endpoint + def get_instance( + self, + *, + instance_id: str, + project: str | None = None, + endpoint: str | None = None, + ) -> Instance: + """ + Get a MaxCompute task instance. + + .. seealso:: https://pyodps.readthedocs.io/en/latest/base-instances.html#instances + + :param instance_id: The ID of the instance to get. + :param project: The project ID to use. + :param endpoint: The endpoint to use. + :return: The MaxCompute task instance. + :raises ValueError: If the instance does not exist. + """ + client = self.get_client(project=project, endpoint=endpoint) + + return client.get_instance(id_=instance_id, project=project) + + @fallback_to_default_project_endpoint + def stop_instance( + self, + *, + instance_id: str, + project: str | None = None, + endpoint: str | None = None, + ) -> None: + """ + Stop a MaxCompute task instance. + + :param instance_id: The ID of the instance to stop. + :param project: The project ID to use. + :param endpoint: The endpoint to use. + """ + client = self.get_client(project=project, endpoint=endpoint) + + try: + client.stop_instance(id_=instance_id, project=project) + self.log.info("Instance %s stop requested.", instance_id) + except Exception: + self.log.exception("Failed to stop instance %s.", instance_id) + raise diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/links/__init__.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/links/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/links/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py new file mode 100644 index 0000000000000..5147b5fc0f6d4 --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/links/maxcompute.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.alibaba.version_compat import AIRFLOW_V_3_0_PLUS + +if TYPE_CHECKING: + from airflow.models import BaseOperator + from airflow.models.taskinstancekey import TaskInstanceKey + from airflow.utils.context import Context + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseOperatorLink + from airflow.sdk.execution_time.xcom import XCom +else: + from airflow.models.baseoperatorlink import BaseOperatorLink # type: ignore[no-redef] + from airflow.models.xcom import XCom # type: ignore[no-redef] + + +class MaxComputeLogViewLink(BaseOperatorLink): + """Helper class for constructing MaxCompute Log View Link.""" + + name = "MaxCompute Log View" + key = "maxcompute_log_view" + + def get_link( + self, + operator: BaseOperator, + *, + ti_key: TaskInstanceKey, + ) -> str: + url = XCom.get_value(key=self.key, ti_key=ti_key) + if not url: + return "" + + return url + + @staticmethod + def persist( + context: Context, + task_instance: BaseOperator, + log_view_url: str, + ): + """ + Persist the log view URL to XCom for later retrieval. + + :param context: The context of the task instance. + :param task_instance: The task instance. + :param log_view_url: The log view URL to persist. + """ + task_instance.xcom_push( + context, + key=MaxComputeLogViewLink.key, + value=log_view_url, + ) diff --git a/providers/alibaba/src/airflow/providers/alibaba/cloud/operators/maxcompute.py b/providers/alibaba/src/airflow/providers/alibaba/cloud/operators/maxcompute.py new file mode 100644 index 0000000000000..c8424c028464e --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/cloud/operators/maxcompute.py @@ -0,0 +1,145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains Alibaba Cloud MaxCompute operators.""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.alibaba.cloud.hooks.maxcompute import MaxComputeHook +from airflow.providers.alibaba.cloud.links.maxcompute import MaxComputeLogViewLink + +if TYPE_CHECKING: + from odps.models import Instance + + from airflow.utils.context import Context + + +class MaxComputeSQLOperator(BaseOperator): + """ + Executes an SQL statement in MaxCompute. + + Waits for the SQL task instance to complete and returns instance id. + + :param sql: The SQL statement to run. + :param project: The project ID to use. + :param endpoint: The endpoint to use. + :param priority: The priority of the SQL statement ranges from 0 to 9, + applicable to projects with the job priority feature enabled. + Takes precedence over the `odps.instance.priority` setting from `hints`. + Defaults to 9. + See https://www.alibabacloud.com/help/en/maxcompute/user-guide/job-priority + for details. + :param running_cluster: The cluster to run the SQL statement on. + :param hints: Hints for setting runtime parameters. See + https://pyodps.readthedocs.io/en/latest/base-sql.html#id4 and + https://www.alibabacloud.com/help/en/maxcompute/user-guide/flag-parameters + for details. + :param aliases: Aliases for the SQL statement. + :param default_schema: The default schema to use. + :param quota_name: The quota name to use. + Defaults to project default quota if not specified. + :param alibabacloud_conn_id: The connection ID to use. Defaults to + `alibabacloud_default` if not specified. + :param cancel_on_kill: Flag which indicates whether to stop running instance + or not when task is killed. Default is True. + """ + + template_fields: Sequence[str] = ( + "sql", + "project", + "endpoint", + "priority", + "running_cluster", + "hints", + "aliases", + "default_schema", + "quota_name", + "alibabacloud_conn_id", + ) + template_ext: Sequence[str] = (".sql",) + template_fields_renderers = {"sql": "sql"} + operator_extra_links = (MaxComputeLogViewLink(),) + + def __init__( + self, + *, + sql: str, + project: str | None = None, + endpoint: str | None = None, + priority: int | None = None, + running_cluster: str | None = None, + hints: dict[str, str] | None = None, + aliases: dict[str, str] | None = None, + default_schema: str | None = None, + quota_name: str | None = None, + alibabacloud_conn_id: str = "alibabacloud_default", + cancel_on_kill: bool = True, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.sql = sql + self.project = project + self.endpoint = endpoint + self.priority = priority + self.running_cluster = running_cluster + self.hints = hints + self.aliases = aliases + self.default_schema = default_schema + self.quota_name = quota_name + self.alibabacloud_conn_id = alibabacloud_conn_id + self.cancel_on_kill = cancel_on_kill + self.hook: MaxComputeHook | None = None + self.instance: Instance | None = None + + def execute(self, context: Context) -> str: + self.hook = MaxComputeHook(alibabacloud_conn_id=self.alibabacloud_conn_id) + + self.instance = self.hook.run_sql( + sql=self.sql, + project=self.project, + endpoint=self.endpoint, + priority=self.priority, + running_cluster=self.running_cluster, + hints=self.hints, + aliases=self.aliases, + default_schema=self.default_schema, + quota_name=self.quota_name, + ) + + MaxComputeLogViewLink.persist( + context=context, task_instance=self, log_view_url=self.instance.get_logview_address() + ) + + self.instance.wait_for_success() + + return self.instance.id + + def on_kill(self) -> None: + instance_id = self.instance.id if self.instance else None + + if instance_id and self.hook and self.cancel_on_kill: + self.hook.stop_instance( + instance_id=instance_id, + project=self.project, + endpoint=self.endpoint, + ) + else: + self.log.info("Skipping to stop instance: %s:%s.%s", self.project, self.endpoint, instance_id) diff --git a/providers/alibaba/src/airflow/providers/alibaba/get_provider_info.py b/providers/alibaba/src/airflow/providers/alibaba/get_provider_info.py index 44cad21c88cc4..39e9e8de250ad 100644 --- a/providers/alibaba/src/airflow/providers/alibaba/get_provider_info.py +++ b/providers/alibaba/src/airflow/providers/alibaba/get_provider_info.py @@ -50,6 +50,10 @@ def get_provider_info(): "integration-name": "Alibaba Cloud AnalyticDB Spark", "python-modules": ["airflow.providers.alibaba.cloud.operators.analyticdb_spark"], }, + { + "integration-name": "Alibaba Cloud MaxCompute", + "python-modules": ["airflow.providers.alibaba.cloud.operators.maxcompute"], + }, ], "sensors": [ { @@ -70,6 +74,14 @@ def get_provider_info(): "integration-name": "Alibaba Cloud AnalyticDB Spark", "python-modules": ["airflow.providers.alibaba.cloud.hooks.analyticdb_spark"], }, + { + "integration-name": "Alibaba Cloud", + "python-modules": ["airflow.providers.alibaba.cloud.hooks.base_alibaba"], + }, + { + "integration-name": "Alibaba Cloud MaxCompute", + "python-modules": ["airflow.providers.alibaba.cloud.hooks.maxcompute"], + }, ], "connection-types": [ { @@ -80,6 +92,15 @@ def get_provider_info(): "hook-class-name": "airflow.providers.alibaba.cloud.hooks.analyticdb_spark.AnalyticDBSparkHook", "connection-type": "adb_spark", }, + { + "hook-class-name": "airflow.providers.alibaba.cloud.hooks.base_alibaba.AlibabaBaseHook", + "connection-type": "alibaba_cloud", + }, + { + "hook-class-name": "airflow.providers.alibaba.cloud.hooks.maxcompute.MaxComputeHook", + "connection-type": "maxcompute", + }, ], "logging": ["airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler"], + "extra-links": ["airflow.providers.alibaba.cloud.links.maxcompute.MaxComputeLogViewLink"], } diff --git a/providers/alibaba/src/airflow/providers/alibaba/version_compat.py b/providers/alibaba/src/airflow/providers/alibaba/version_compat.py new file mode 100644 index 0000000000000..48d122b669696 --- /dev/null +++ b/providers/alibaba/src/airflow/providers/alibaba/version_compat.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# NOTE! THIS FILE IS COPIED MANUALLY IN OTHER PROVIDERS DELIBERATELY TO AVOID ADDING UNNECESSARY +# DEPENDENCIES BETWEEN PROVIDERS. IF YOU WANT TO ADD CONDITIONAL CODE IN YOUR PROVIDER THAT DEPENDS +# ON AIRFLOW VERSION, PLEASE COPY THIS FILE TO THE ROOT PACKAGE OF YOUR PROVIDER AND IMPORT +# THOSE CONSTANTS FROM IT RATHER THAN IMPORTING THEM FROM ANOTHER PROVIDER OR TEST CODE +# +from __future__ import annotations + + +def get_base_airflow_version_tuple() -> tuple[int, int, int]: + from packaging.version import Version + + from airflow import __version__ + + airflow_version = Version(__version__) + return airflow_version.major, airflow_version.minor, airflow_version.micro + + +AIRFLOW_V_3_0_PLUS = get_base_airflow_version_tuple() >= (3, 0, 0) diff --git a/providers/alibaba/tests/system/alibaba/example_maxcompute_sql.py b/providers/alibaba/tests/system/alibaba/example_maxcompute_sql.py new file mode 100644 index 0000000000000..c9fb6bb3b22fd --- /dev/null +++ b/providers/alibaba/tests/system/alibaba/example_maxcompute_sql.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.alibaba.cloud.operators.maxcompute import MaxComputeSQLOperator + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") +DAG_ID = "maxcompute_sql_dag" + +SQL = "SELECT 1" + +with DAG( + dag_id=DAG_ID, + start_date=datetime(2021, 1, 1), + schedule="@once", + tags=["example", "maxcompute"], + catchup=False, +) as dag: + run_sql = MaxComputeSQLOperator( + task_id="run_sql", + sql=SQL, + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/alibaba/tests/unit/alibaba/cloud/hooks/test_base_alibaba.py b/providers/alibaba/tests/unit/alibaba/cloud/hooks/test_base_alibaba.py new file mode 100644 index 0000000000000..b32a6296451b7 --- /dev/null +++ b/providers/alibaba/tests/unit/alibaba/cloud/hooks/test_base_alibaba.py @@ -0,0 +1,53 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from airflow.providers.alibaba.cloud.hooks.base_alibaba import AccessKeyCredentials, AlibabaBaseHook + +BASE_ALIBABA_HOOK_MODULE = "airflow.providers.alibaba.cloud.hooks.base_alibaba.{}" +MOCK_MAXCOMPUTE_CONN_ID = "mock_id" +MOCK_ACCESS_KEY_ID = "mock_access_key_id" +MOCK_ACCESS_KEY_SECRET = "mock_access_key_secret" + + +class TestAlibabaBaseHook: + def setup_method(self): + with mock.patch( + BASE_ALIBABA_HOOK_MODULE.format("AlibabaBaseHook.get_connection"), + ) as mock_get_connection: + mock_conn = mock.MagicMock() + mock_conn.extra_dejson = { + "access_key_id": MOCK_ACCESS_KEY_ID, + "access_key_secret": MOCK_ACCESS_KEY_SECRET, + } + + mock_get_connection.return_value = mock_conn + self.hook = AlibabaBaseHook(alibabacloud_conn_id=MOCK_MAXCOMPUTE_CONN_ID) + + def test_get_access_key_credential(self): + creds = AccessKeyCredentials( + access_key_id=MOCK_ACCESS_KEY_ID, + access_key_secret=MOCK_ACCESS_KEY_SECRET, + ) + + creds = self.hook.get_access_key_credential() + + assert creds.access_key_id == MOCK_ACCESS_KEY_ID + assert creds.access_key_secret == MOCK_ACCESS_KEY_SECRET diff --git a/providers/alibaba/tests/unit/alibaba/cloud/hooks/test_maxcompute.py b/providers/alibaba/tests/unit/alibaba/cloud/hooks/test_maxcompute.py new file mode 100644 index 0000000000000..f5a1e59c639a2 --- /dev/null +++ b/providers/alibaba/tests/unit/alibaba/cloud/hooks/test_maxcompute.py @@ -0,0 +1,117 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from airflow.providers.alibaba.cloud.hooks.maxcompute import MaxComputeHook + +MAXCOMPUTE_HOOK_MODULE = "airflow.providers.alibaba.cloud.hooks.maxcompute.MaxComputeHook.{}" +MOCK_MAXCOMPUTE_CONN_ID = "mock_id" +MOCK_MAXCOMPUTE_PROJECT = "mock_project" +MOCK_MAXCOMPUTE_ENDPOINT = "mock_endpoint" + + +class TestMaxComputeHook: + def setup_method(self): + with mock.patch( + MAXCOMPUTE_HOOK_MODULE.format("get_connection"), + ) as mock_get_connection: + mock_conn = mock.MagicMock() + mock_conn.extra_dejson = { + "access_key_id": "mock_access_key_id", + "access_key_secret": "mock_access_key_secret", + "project": MOCK_MAXCOMPUTE_PROJECT, + "endpoint": MOCK_MAXCOMPUTE_ENDPOINT, + } + + mock_get_connection.return_value = mock_conn + self.hook = MaxComputeHook(maxcompute_conn_id=MOCK_MAXCOMPUTE_CONN_ID) + + @mock.patch(MAXCOMPUTE_HOOK_MODULE.format("get_client")) + def test_run_sql(self, mock_get_client): + mock_instance = mock.MagicMock() + mock_client = mock.MagicMock() + mock_client.run_sql.return_value = mock_instance + mock_get_client.return_value = mock_client + + sql = "SELECT 1" + priority = 1 + running_cluster = "mock_running_cluster" + hints = {"hint_key": "hint_value"} + aliases = {"alias_key": "alias_value"} + default_schema = "mock_default_schema" + quota_name = "mock_quota_name" + + instance = self.hook.run_sql( + sql=sql, + priority=priority, + running_cluster=running_cluster, + hints=hints, + aliases=aliases, + default_schema=default_schema, + quota_name=quota_name, + ) + + assert instance == mock_instance + + assert mock_client.run_sql.asssert_called_once_with( + sql=sql, + priority=priority, + running_cluster=running_cluster, + hints=hints, + aliases=aliases, + default_schema=default_schema, + quota_name=quota_name, + ) + + @mock.patch(MAXCOMPUTE_HOOK_MODULE.format("get_client")) + def test_get_instance(self, mock_get_client): + mock_client = mock.MagicMock() + mock_client.exist_instance.return_value = True + mock_instance = mock.MagicMock() + mock_client.get_instance.return_value = mock_instance + mock_get_client.return_value = mock_client + instance_id = "mock_instance_id" + + instance = self.hook.get_instance( + instance_id=instance_id, + project=MOCK_MAXCOMPUTE_PROJECT, + endpoint=MOCK_MAXCOMPUTE_ENDPOINT, + ) + + mock_client.get_instance.assert_called_once_with( + id_=instance_id, + project=MOCK_MAXCOMPUTE_PROJECT, + ) + + assert instance == mock_instance + + @mock.patch(MAXCOMPUTE_HOOK_MODULE.format("get_client")) + def test_stop_instance_success(self, mock_get_client): + mock_client = mock.MagicMock() + mock_get_client.return_value = mock_client + instance_id = "mock_instance_id" + + self.hook.stop_instance( + instance_id=instance_id, + project=MOCK_MAXCOMPUTE_PROJECT, + endpoint=MOCK_MAXCOMPUTE_ENDPOINT, + ) + + mock_client.stop_instance.assert_called_once() diff --git a/providers/alibaba/tests/unit/alibaba/cloud/links/__init__.py b/providers/alibaba/tests/unit/alibaba/cloud/links/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/alibaba/tests/unit/alibaba/cloud/links/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/alibaba/tests/unit/alibaba/cloud/links/test_maxcompute.py b/providers/alibaba/tests/unit/alibaba/cloud/links/test_maxcompute.py new file mode 100644 index 0000000000000..2820cf7c1223a --- /dev/null +++ b/providers/alibaba/tests/unit/alibaba/cloud/links/test_maxcompute.py @@ -0,0 +1,66 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.alibaba.cloud.links.maxcompute import MaxComputeLogViewLink + +MAXCOMPUTE_LINK_MODULE = "airflow.providers.alibaba.cloud.links.maxcompute.{}" +MOCK_TASK_ID = "run_sql" +MOCK_SQL = "SELECT 1" +MOCK_INSTANCE_ID = "mock_instance_id" + + +class TestMaxComputeLogViewLink: + @pytest.mark.parametrize( + "xcom_value, expected_link", + [ + pytest.param("http://mock_url.com", "http://mock_url.com", id="has-log-link"), + pytest.param(None, "", id="no-log-link"), + ], + ) + @mock.patch(MAXCOMPUTE_LINK_MODULE.format("XCom")) + def test_get_link(self, mock_xcom, xcom_value, expected_link): + mock_xcom.get_value.return_value = xcom_value + + link = MaxComputeLogViewLink().get_link( + operator=mock.MagicMock(), + ti_key=mock.MagicMock(), + ) + + assert link == expected_link + + def test_persist(self): + mock_context = mock.MagicMock() + mock_task_instance = mock.MagicMock() + mock_url = "mock_url" + + MaxComputeLogViewLink.persist( + context=mock_context, + task_instance=mock_task_instance, + log_view_url=mock_url, + ) + + mock_task_instance.xcom_push.assert_called_once_with( + mock_context, + key=MaxComputeLogViewLink.key, + value=mock_url, + ) diff --git a/providers/alibaba/tests/unit/alibaba/cloud/operators/test_maxcompute.py b/providers/alibaba/tests/unit/alibaba/cloud/operators/test_maxcompute.py new file mode 100644 index 0000000000000..4bbd6f1cd66f7 --- /dev/null +++ b/providers/alibaba/tests/unit/alibaba/cloud/operators/test_maxcompute.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from airflow.providers.alibaba.cloud.operators.maxcompute import MaxComputeSQLOperator + +MAXCOMPUTE_OPERATOR_MODULE = "airflow.providers.alibaba.cloud.operators.maxcompute.{}" +MOCK_TASK_ID = "run_sql" +MOCK_SQL = "SELECT 1" +MOCK_INSTANCE_ID = "mock_instance_id" + + +class TestMaxComputeSQLOperator: + @mock.patch(MAXCOMPUTE_OPERATOR_MODULE.format("MaxComputeLogViewLink")) + @mock.patch(MAXCOMPUTE_OPERATOR_MODULE.format("MaxComputeHook")) + def test_execute(self, mock_hook, mock_log_view_link): + instance_mock = mock.MagicMock() + instance_mock.id = MOCK_INSTANCE_ID + instance_mock.get_logview_address.return_value = "http://mock_logview_address" + + mock_hook.return_value.run_sql.return_value = instance_mock + + op = MaxComputeSQLOperator( + task_id=MOCK_TASK_ID, + sql=MOCK_SQL, + ) + + instance_id = op.execute(context=mock.MagicMock()) + + assert instance_id == instance_mock.id + + mock_hook.return_value.run_sql.assert_called_once_with( + project=op.project, + sql=op.sql, + endpoint=op.endpoint, + priority=op.priority, + running_cluster=op.running_cluster, + hints=op.hints, + aliases=op.aliases, + default_schema=op.default_schema, + quota_name=op.quota_name, + ) + + mock_log_view_link.persist.assert_called_once_with( + context=mock.ANY, + task_instance=op, + log_view_url=instance_mock.get_logview_address.return_value, + ) + + @mock.patch(MAXCOMPUTE_OPERATOR_MODULE.format("MaxComputeHook")) + def test_on_kill(self, mock_hook): + instance_mock = mock.MagicMock() + instance_mock.id = MOCK_INSTANCE_ID + + mock_hook.return_value.run_sql.return_value = instance_mock + + op = MaxComputeSQLOperator( + task_id=MOCK_TASK_ID, + sql=MOCK_SQL, + cancel_on_kill=False, + ) + op.execute(context=mock.MagicMock()) + + op.on_kill() + mock_hook.return_value.cancel_job.assert_not_called() + + op.cancel_on_kill = True + op.on_kill() + mock_hook.return_value.stop_instance.assert_called_once_with( + instance_id=instance_mock.id, project=op.project, endpoint=op.endpoint + ) diff --git a/providers/alibaba/tests/unit/alibaba/cloud/test_exceptions.py b/providers/alibaba/tests/unit/alibaba/cloud/test_exceptions.py new file mode 100644 index 0000000000000..924a6cb2852a6 --- /dev/null +++ b/providers/alibaba/tests/unit/alibaba/cloud/test_exceptions.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.alibaba.cloud.hooks.maxcompute import MaxComputeConfigurationException + + +def test_maxcompute_configuration_exception_message(): + message = "Project or endpoint missing" + with pytest.raises(MaxComputeConfigurationException) as e: + raise MaxComputeConfigurationException(message) + + assert str(e.value) == message