diff --git a/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py b/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py
index 59fc8bf647a3b..d55569c8a23bc 100644
--- a/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py
+++ b/providers/yandex/src/airflow/providers/yandex/operators/dataproc.py
@@ -14,12 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 from __future__ import annotations
 
 from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
+import yandexcloud
+
 from airflow.providers.yandex.hooks.dataproc import DataprocHook
 from airflow.providers.yandex.version_compat import BaseOperator
 
@@ -54,6 +57,7 @@ class DataprocCreateClusterOperator(BaseOperator):
         Currently there are ru-central1-a, ru-central1-b and ru-central1-c.
     :param service_account_id: Service account id for the cluster.
         Service account can be created inside the folder.
+    :param environment: Environment for the cluster. Possible options: PRODUCTION, PRESTABLE.
     :param masternode_resource_preset: Resources preset (CPU+RAM configuration)
         for the primary node of the cluster.
     :param masternode_disk_size: Masternode storage size in GiB.
@@ -96,6 +100,7 @@ class DataprocCreateClusterOperator(BaseOperator):
         Docs: https://cloud.yandex.com/docs/data-proc/concepts/logs
     :param initialization_actions: Set of init-actions to run when cluster starts.
         Docs: https://cloud.yandex.com/docs/data-proc/concepts/init-action
+    :param oslogin_enabled: Enable authorization via OS Login for the cluster.
     :param labels: Cluster labels as key:value pairs. No more than 64 per resource.
         Docs: https://cloud.yandex.com/docs/resource-manager/concepts/labels
     """
@@ -113,6 +118,7 @@ def __init__(
         s3_bucket: str | None = None,
         zone: str = "ru-central1-b",
         service_account_id: str | None = None,
+        environment: str | None = None,
         masternode_resource_preset: str | None = None,
         masternode_disk_size: int | None = None,
         masternode_disk_type: str | None = None,
@@ -138,6 +144,7 @@ def __init__(
         security_group_ids: Iterable[str] | None = None,
         log_group_id: str | None = None,
         initialization_actions: Iterable[InitializationAction] | None = None,
+        oslogin_enabled: bool = False,
         labels: dict[str, str] | None = None,
         **kwargs,
     ) -> None:
@@ -156,6 +163,7 @@ def __init__(
         self.s3_bucket = s3_bucket
         self.zone = zone
         self.service_account_id = service_account_id
+        self.environment = environment
         self.masternode_resource_preset = masternode_resource_preset
         self.masternode_disk_size = masternode_disk_size
         self.masternode_disk_type = masternode_disk_type
@@ -180,6 +188,7 @@ def __init__(
         self.security_group_ids = security_group_ids
         self.log_group_id = log_group_id
         self.initialization_actions = initialization_actions
+        self.oslogin_enabled = oslogin_enabled
         self.labels = labels
 
         self.hook: DataprocHook | None = None
@@ -188,6 +197,11 @@ def execute(self, context: Context) -> dict:
         self.hook = DataprocHook(
             yandex_conn_id=self.yandex_conn_id,
         )
+        kwargs_depends_on_version = {}
+        if yandexcloud.__version__ >= "0.350.0":
+            kwargs_depends_on_version.update(
+                {"oslogin_enabled": self.oslogin_enabled, "environment": self.environment}
+            )
         operation_result = self.hook.dataproc_client.create_cluster(
             folder_id=self.folder_id,
             cluster_name=self.cluster_name,
@@ -233,6 +247,7 @@ def execute(self, context: Context) -> dict:
             ]
             if self.initialization_actions
             else None,
+            **kwargs_depends_on_version,
         )
         cluster_id = operation_result.response.id
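In the `execute()` hunk above, the two new parameters are gated behind the installed `yandexcloud` SDK version: they are collected in `kwargs_depends_on_version` and splatted into `create_cluster()`, so an older SDK is never handed keyword arguments it does not accept. Below is a minimal standalone sketch of that pattern, not the provider's code: the helper name `build_extra_kwargs` is hypothetical, and it swaps the diff's plain string comparison for `packaging.version.Version` (already a dependency of Airflow), which compares release segments numerically.

```python
# Sketch: pass newer SDK-only kwargs conditionally, based on the installed version.
# Assumes the `yandexcloud` and `packaging` packages are importable.
from __future__ import annotations

from packaging.version import Version

import yandexcloud


def build_extra_kwargs(environment: str | None, oslogin_enabled: bool) -> dict:
    """Return kwargs supported only by yandexcloud >= 0.350.0, else an empty dict."""
    extra: dict = {}
    # Version() compares release segments numerically, so "0.1000.0" sorts after
    # "0.350.0"; a plain string comparison would order those two the other way.
    if Version(yandexcloud.__version__) >= Version("0.350.0"):
        extra["environment"] = environment
        extra["oslogin_enabled"] = oslogin_enabled
    return extra


# Usage mirrors the operator: the dict is splatted into the SDK call, e.g.
# client.create_cluster(..., **build_extra_kwargs(None, False)).
```

String comparison happens to order the yandexcloud releases involved here correctly, but it misorders release segments with different digit counts (as strings, `"0.1000.0" < "0.350.0"`), which is why the sketch parses the versions first.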
diff --git a/providers/yandex/tests/unit/yandex/operators/test_dataproc.py b/providers/yandex/tests/unit/yandex/operators/test_dataproc.py
index 3a53a86daf08a..66cb883909e44 100644
--- a/providers/yandex/tests/unit/yandex/operators/test_dataproc.py
+++ b/providers/yandex/tests/unit/yandex/operators/test_dataproc.py
@@ -93,6 +93,7 @@ def setup_method(self):
     @patch("airflow.providers.yandex.utils.credentials.get_credentials")
     @patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
     @patch("yandexcloud._wrappers.dataproc.Dataproc.create_cluster")
+    @patch("yandexcloud.__version__", "0.308.0")
     def test_create_cluster(self, mock_create_cluster, *_):
         operator = DataprocCreateClusterOperator(
             task_id="create_cluster",
@@ -154,6 +155,73 @@ def test_create_cluster(self, mock_create_cluster, *_):
             ]
         )
 
+    @patch("airflow.providers.yandex.utils.credentials.get_credentials")
+    @patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
+    @patch("yandexcloud._wrappers.dataproc.Dataproc.create_cluster")
+    @patch("yandexcloud.__version__", "0.350.0")
+    def test_create_cluster_with_350_sdk(self, mock_create_cluster, *_):
+        operator = DataprocCreateClusterOperator(
+            task_id="create_cluster",
+            ssh_public_keys=SSH_PUBLIC_KEYS,
+            folder_id=FOLDER_ID,
+            subnet_id=SUBNET_ID,
+            zone=AVAILABILITY_ZONE_ID,
+            connection_id=CONNECTION_ID,
+            s3_bucket=S3_BUCKET_NAME_FOR_LOGS,
+            cluster_image_version=CLUSTER_IMAGE_VERSION,
+            log_group_id=LOG_GROUP_ID,
+        )
+        context = {"task_instance": MagicMock()}
+        operator.execute(context)
+        mock_create_cluster.assert_called_once_with(
+            cluster_description="",
+            cluster_image_version="1.4",
+            cluster_name=None,
+            computenode_count=0,
+            computenode_disk_size=None,
+            computenode_disk_type=None,
+            computenode_resource_preset=None,
+            computenode_max_hosts_count=None,
+            computenode_measurement_duration=None,
+            computenode_warmup_duration=None,
+            computenode_stabilization_duration=None,
+            computenode_preemptible=False,
+            computenode_cpu_utilization_target=None,
+            computenode_decommission_timeout=None,
+            datanode_count=1,
+            datanode_disk_size=None,
+            datanode_disk_type=None,
+            datanode_resource_preset=None,
+            folder_id="my_folder_id",
+            masternode_disk_size=None,
+            masternode_disk_type=None,
+            masternode_resource_preset=None,
+            s3_bucket="my_bucket_name",
+            service_account_id=None,
+            services=("HDFS", "YARN", "MAPREDUCE", "HIVE", "SPARK"),
+            ssh_public_keys=[
+                "ssh-rsa AAA5B3NzaC1yc2EAA1ADA2ABA3AA4QCxO38tKA0XIs9ivPxt7AYdf3bgtAR1ow3Qkb9GPQ6wkFHQq"
+                "cFDe6faKCxH6iDRt2o4D8L8Bx6zN42uZSB0nf8jkIxFTcEU3mFSXEbWByg78ao3dMrAAj1tyr1H1pON6P0="
+            ],
+            subnet_id="my_subnet_id",
+            zone="ru-central1-c",
+            log_group_id=LOG_GROUP_ID,
+            properties=None,
+            enable_ui_proxy=False,
+            host_group_ids=None,
+            security_group_ids=None,
+            labels=None,
+            initialization_actions=None,
+            environment=None,
+            oslogin_enabled=False,
+        )
+        context["task_instance"].xcom_push.assert_has_calls(
+            [
+                call(key="cluster_id", value=mock_create_cluster().response.id),
+                call(key="yandexcloud_connection_id", value=CONNECTION_ID),
+            ]
+        )
+
     @patch("airflow.providers.yandex.utils.credentials.get_credentials")
     @patch(f"{BASEHOOK_PATCH_PATH}.get_connection")
     @patch("yandexcloud._wrappers.dataproc.Dataproc.delete_cluster")
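Both tests pin `yandexcloud.__version__` with `unittest.mock.patch` (decorator form, passing the replacement value positionally), so each case sees a fixed SDK version regardless of what is actually installed. A self-contained sketch of that technique, assuming only that the `yandexcloud` package is importable; `gated_kwargs` is a hypothetical stand-in for the operator's gating logic, not provider code:

```python
# Sketch: patch yandexcloud.__version__ to exercise both branches of a version
# gate without installing two different SDK releases.
from unittest.mock import patch

import yandexcloud


def gated_kwargs() -> dict:
    # Reads the version at call time, which is what lets mock.patch intercept it.
    if yandexcloud.__version__ >= "0.350.0":
        return {"environment": None, "oslogin_enabled": False}
    return {}


def test_old_sdk_omits_new_kwargs():
    with patch("yandexcloud.__version__", "0.308.0"):
        assert gated_kwargs() == {}


def test_new_sdk_includes_new_kwargs():
    with patch("yandexcloud.__version__", "0.350.0"):
        assert gated_kwargs() == {"environment": None, "oslogin_enabled": False}
```

Because `gated_kwargs` reads `yandexcloud.__version__` at call time rather than caching it at import, the patch is visible inside the `with` block and is automatically undone on exit.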