diff --git a/airflow-core/tests/unit/always/test_project_structure.py b/airflow-core/tests/unit/always/test_project_structure.py index 1fab1d8e49800..fb55ed7e9c31c 100644 --- a/airflow-core/tests/unit/always/test_project_structure.py +++ b/airflow-core/tests/unit/always/test_project_structure.py @@ -470,6 +470,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest "airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaBaseOperator", "airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator", "airflow.providers.google.cloud.operators.vertex_ai.ray.RayBaseOperator", + "airflow.providers.google.cloud.operators.ray.RayJobBaseOperator", "airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator", "airflow.providers.google.marketing_platform.operators.search_ads._GoogleSearchAdsBaseOperator", } diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 078b02b1dca5c..17ee36adc6407 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -981,6 +981,7 @@ Jira jira jitter JobComplete +JobDetails JobExists jobflow jobId diff --git a/providers/google/docs/operators/cloud/ray.rst b/providers/google/docs/operators/cloud/ray.rst new file mode 100644 index 0000000000000..2f234a988e6b1 --- /dev/null +++ b/providers/google/docs/operators/cloud/ray.rst @@ -0,0 +1,90 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +Ray Job Operators +================= + +The Ray Job operators provide a high-level interface for interacting with remote Ray clusters +using the Ray Jobs API. These operators can be used with clusters running on Google Cloud Vertex AI Ray, +GKE (self-managed Ray clusters) or any Ray cluster reachable through a dashboard address or Ray Client address. + +The operators allow you to submit jobs, monitor their progress, retrieve logs, +and manage job lifecycle from Airflow. + +Submitting Ray Jobs +^^^^^^^^^^^^^^^^^^^ + +The :class:`~airflow.providers.google.cloud.operators.ray.RaySubmitJobOperator` +submits a job to a Ray cluster and optionally waits for completion. + +It supports waiting for job completion with ``wait_for_job_done`` +and retrieving logs after completion with ``get_job_logs`` parameters. + +.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_ray_submit_job] + :end-before: [END how_to_ray_submit_job] + +Stopping Ray Jobs +^^^^^^^^^^^^^^^^^ + +Use :class:`~airflow.providers.google.cloud.operators.ray.RayStopJobOperator` +to stop a running Ray job identified by its job ID. + +.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_ray_stop_job] + :end-before: [END how_to_ray_stop_job] + +Deleting Ray Jobs +^^^^^^^^^^^^^^^^^ + +Use :class:`~airflow.providers.google.cloud.operators.ray.RayDeleteJobOperator` +to delete a job and its metadata after it reaches a terminal state. + +.. 
exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_ray_delete_job] + :end-before: [END how_to_ray_delete_job] + +Retrieving Job Information +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`~airflow.providers.google.cloud.operators.ray.RayGetJobInfoOperator` +retrieves detailed information about a Ray job, including status, timestamps, +entrypoint, metadata, and runtime environment. + +.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_ray_get_job_info] + :end-before: [END how_to_ray_get_job_info] + +Listing Jobs +^^^^^^^^^^^^ + +Use :class:`~airflow.providers.google.cloud.operators.ray.RayListJobsOperator` +to list all jobs that have run on the cluster. + +.. exampleinclude:: /../../google/tests/system/google/cloud/ray/example_ray_job.py + :language: python + :dedent: 4 + :start-after: [START how_to_ray_list_jobs] + :end-before: [END how_to_ray_list_jobs] diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml index 46f4c266e6e83..0d96157eb1977 100644 --- a/providers/google/provider.yaml +++ b/providers/google/provider.yaml @@ -462,6 +462,11 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-google/operators/cloud/gen_ai.rst tags: [gcp] + - integration-name: Google Ray + external-doc-url: https://docs.cloud.google.com/vertex-ai/docs/open-source/ray-on-vertex-ai/overview + how-to-guide: + - /docs/apache-airflow-providers-google/operators/cloud/ray.rst + tags: [gcp] operators: - integration-name: Google Ads @@ -624,6 +629,9 @@ operators: - integration-name: Google Cloud Generative AI python-modules: - airflow.providers.google.cloud.operators.gen_ai + - integration-name: Google Ray + python-modules: + - airflow.providers.google.cloud.operators.ray sensors: - integration-name: Google BigQuery @@ -905,6 +913,9 @@ hooks: - integration-name: Google Cloud 
Generative AI python-modules: - airflow.providers.google.cloud.hooks.gen_ai + - integration-name: Google Ray + python-modules: + - airflow.providers.google.cloud.hooks.ray bundles: - integration-name: Google Cloud Storage (GCS) @@ -1258,6 +1269,7 @@ extra-links: - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink + - airflow.providers.google.cloud.links.ray.RayJobLink secrets-backends: - airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/ray.py b/providers/google/src/airflow/providers/google/cloud/hooks/ray.py new file mode 100644 index 0000000000000..bda6ab086aee1 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/hooks/ray.py @@ -0,0 +1,234 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains a Google Cloud Ray Job hook.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +from ray.job_submission import JobSubmissionClient + +from airflow.providers.google.common.hooks.base_google import GoogleBaseHook + +if TYPE_CHECKING: + from ray.dashboard.modules.job.common import JobStatus + from ray.dashboard.modules.job.pydantic_models import JobDetails + +VERTEX_RAY_DOMAIN = "aiplatform-training.googleusercontent.com" + + +class RayJobHook(GoogleBaseHook): + """Hook for Jobs APIs.""" + + def _is_vertex_ray_address(self, address: str) -> bool: + """Return True if address points to Vertex Ray dashboard host.""" + parsed = urlparse(address if "://" in address else f"https://{address}") + hostname = parsed.hostname + if not hostname: + return False + return hostname.endswith(VERTEX_RAY_DOMAIN) + + def get_client(self, address: str) -> JobSubmissionClient: + """ + Create a client for submitting and interacting with jobs on a remote cluster. + + :param address: Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". 
+ """ + if self._is_vertex_ray_address(address): + return JobSubmissionClient(f"vertex_ray://{address}") + return JobSubmissionClient(address=address) + + def serialize_job_obj(self, job_obj: JobDetails) -> dict: + """Serialize JobDetails to a plain dict.""" + if hasattr(job_obj, "model_dump"): # Pydantic v2 + return job_obj.model_dump(exclude_none=True) + if hasattr(job_obj, "dict"): # Pydantic v1 + return job_obj.dict(exclude_none=True) + return dict(job_obj) + + def submit_job( + self, + entrypoint: str, + cluster_address: str, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + ) -> str: + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param entrypoint: Required. The shell command to run for this job. + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. + :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. 
+ :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + """ + job_id = self.get_client(address=cluster_address).submit_job( + entrypoint=entrypoint, + runtime_env=runtime_env, + metadata=metadata, + submission_id=submission_id, + entrypoint_num_cpus=entrypoint_num_cpus, + entrypoint_num_gpus=entrypoint_num_gpus, + entrypoint_memory=entrypoint_memory, + entrypoint_resources=entrypoint_resources, + ) + return job_id + + def stop_job( + self, + job_id: str, + cluster_address: str, + ) -> bool: + """ + Stop Job on Ray cluster. + + :param job_id: Required. The job ID or submission ID for the job to be stopped. + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :return: True if the job was stopped, otherwise False. + """ + return self.get_client(address=cluster_address).stop_job(job_id=job_id) + + def delete_job( + self, + job_id: str, + cluster_address: str, + ) -> bool: + """ + Delete Job on Ray cluster in a terminal state and all of its associated data. + + If the job is not already in a terminal state, raises an error. + This does not delete the job logs from disk. + Submitting a job with the same submission ID as a previously + deleted job is not supported and may lead to unexpected behavior. + + :param job_id: Required. 
The job ID or submission ID for the job to be deleted. + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :return: True if the job was deleted, otherwise False. + """ + return self.get_client(address=cluster_address).delete_job(job_id=job_id) + + def get_job_info( + self, + job_id: str, + cluster_address: str, + ) -> JobDetails: + """ + Get the latest status and other information associated with a Job on Ray cluster. + + :param job_id: Required. The job ID or submission ID for the job to be retrieved. + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :return: The JobDetails for the job. + """ + return self.get_client(address=cluster_address).get_job_info(job_id=job_id) + + def list_jobs( + self, + cluster_address: str, + ) -> list[JobDetails]: + """ + List all jobs along with their status and other information. + + Lists all jobs that have ever run on the cluster, including jobs that are + currently running and jobs that are no longer running. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". 
+ """ + return self.get_client(address=cluster_address).list_jobs() + + def get_job_status( + self, + job_id: str, + cluster_address: str, + ) -> JobStatus: + """ + Get the most recent status of a Job on Ray cluster. + + :param job_id: Required. The job ID or submission ID for the job to be retrieved. + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :return: The JobStatus of the job. + """ + return self.get_client(address=cluster_address).get_job_status(job_id=job_id) + + def get_job_logs( + self, + job_id: str, + cluster_address: str, + ) -> str: + """ + Get all logs produced by a Job on Ray cluster. + + :param job_id: Required. The job ID or submission ID for the job to be retrieved. + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :return: A string containing the full logs of the job. + """ + return self.get_client(address=cluster_address).get_job_logs(job_id=job_id) diff --git a/providers/google/src/airflow/providers/google/cloud/links/ray.py b/providers/google/src/airflow/providers/google/cloud/links/ray.py new file mode 100644 index 0000000000000..c2ee498d76246 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/links/ray.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + +RAY_JOB_BASE_LINK = "http://{cluster_address}/#/jobs" +RAY_JOB_LINK = RAY_JOB_BASE_LINK + "/{job_id}" + + +class RayJobLink(BaseGoogleLink): + """Helper class for constructing Job on Ray cluster Link.""" + + name = "Ray Job" + key = "ray_job" + format_str = RAY_JOB_LINK + + @classmethod + def persist(cls, context: Context, **value): + cluster_address = value.get("cluster_address") + job_id = value.get("job_id") + super().persist( + context=context, + cluster_address=cluster_address, + job_id=job_id, + ) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/ray.py b/providers/google/src/airflow/providers/google/cloud/operators/ray.py new file mode 100644 index 0000000000000..c243c7d8e268f --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/operators/ray.py @@ -0,0 +1,449 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This module contains a Google Cloud Ray Job operators.""" + +from __future__ import annotations + +import time +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from google.api_core.exceptions import NotFound +from ray.dashboard.modules.job.common import JobStatus + +from airflow.providers.common.compat.sdk import AirflowNotFoundException, AirflowTaskTimeout +from airflow.providers.google.cloud.hooks.ray import RayJobHook +from airflow.providers.google.cloud.links.ray import RayJobLink +from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.compat.sdk import Context + + +TERMINAL_STATUSES = {JobStatus.SUCCEEDED.value, JobStatus.FAILED.value} + + +class OperationFailedException(Exception): + """Custom exception to handle failing operations on Jobs.""" + + pass + + +class RayJobBaseOperator(GoogleCloudBaseOperator): + """ + Base class for Jobs on Ray operators. + + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. 
+ If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = ( + "gcp_conn_id", + "impersonation_chain", + ) + + def __init__( + self, + gcp_conn_id: str = "google_cloud_default", + impersonation_chain: str | Sequence[str] | None = None, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.gcp_conn_id = gcp_conn_id + self.impersonation_chain = impersonation_chain + + @cached_property + def hook(self) -> RayJobHook: + return RayJobHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) + + +class RaySubmitJobOperator(RayJobBaseOperator): + """ + Submit and execute Job on Ray cluster. + + When a job is submitted, it runs once to completion or failure. Retries or + different runs with different parameters should be handled by the + submitter. Jobs are bound to the lifetime of a Ray cluster, so if the + cluster goes down, all running jobs on that cluster will be terminated. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :param entrypoint: Required. The shell command to run for this job. + :param get_job_logs: If set to True, the operator will wait until the end of + Job execution and output the logs. + :param wait_for_job_done: If set to True, the operator will wait until the end of + Job execution. 
Please note, that if the Job will fail during execution and + this parameter is set to False, there will be no indication of the failure. + :param submission_id: A unique ID for this job. + :param runtime_env: The runtime environment to install and run this job in. + :param metadata: Arbitrary data to store along with this job. + :param entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_num_gpus: The quantity of GPUs to reserve for the execution + of the entrypoint command, separately from any tasks or actors launched + by it. Defaults to 0. + :param entrypoint_memory: The quantity of memory to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. Defaults to 0. + :param entrypoint_resources: The quantity of custom resources to reserve for the + execution of the entrypoint command, separately from any tasks or + actors launched by it. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"entrypoint", "submission_id", "cluster_address"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = (RayJobLink(),) + + def __init__( + self, + cluster_address: str, + entrypoint: str, + get_job_logs: bool | None = False, + wait_for_job_done: bool | None = False, + runtime_env: dict[str, Any] | None = None, + metadata: dict[str, str] | None = None, + submission_id: str | None = None, + entrypoint_num_cpus: int | float | None = None, + entrypoint_num_gpus: int | float | None = None, + entrypoint_memory: int | None = None, + entrypoint_resources: dict[str, float] | None = None, + submit_job_timeout: float = 60 * 30, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.get_job_logs = get_job_logs + self.wait_for_job_done = wait_for_job_done + self.entrypoint = entrypoint + self.runtime_env = runtime_env + self.metadata = metadata + self.submission_id = submission_id + self.entrypoint_num_cpus = entrypoint_num_cpus + self.entrypoint_num_gpus = entrypoint_num_gpus + self.entrypoint_memory = entrypoint_memory + self.entrypoint_resources = entrypoint_resources + self.submit_job_timeout = submit_job_timeout + + def _check_job_status( + self, cluster_address: str, job_id: str, timeout: float, polling_interval: float = 5.0 + ) -> str: + "Check if the Job has reached terminated state." + start = time.monotonic() + + while True: + job_status = self.hook.get_job_status(cluster_address=cluster_address, job_id=job_id) + if job_status in TERMINAL_STATUSES: + self.log.info("Job has finished execution with status: %s", job_status) + return job_status + + if time.monotonic() - start > timeout: + raise AirflowTaskTimeout( + f"Timeout waiting for Ray Job {job_id} to finish. 
Last status: {job_status}" + ) + + self.log.info("Job status: %s...", job_status) + time.sleep(polling_interval) + + def _get_job_logs(self, cluster_address, job_id): + "Output Job logs." + logs = self.hook.get_job_logs(cluster_address=cluster_address, job_id=job_id) + self.log.info("Got job logs:\n%s\n", logs) + + def execute(self, context: Context): + if self.get_job_logs and not self.wait_for_job_done: + raise ValueError( + "Retrieving Job logs can be possible only after Job completion. " + "Please, enable wait_for_job_done parameter to be able to get logs." + ) + try: + self.log.info("Submitting Job on a Ray cluster...") + submitted_job_id = self.hook.submit_job( + cluster_address=self.cluster_address, + entrypoint=self.entrypoint, + runtime_env=self.runtime_env, + metadata=self.metadata, + submission_id=self.submission_id, + entrypoint_num_cpus=self.entrypoint_num_cpus, + entrypoint_num_gpus=self.entrypoint_num_gpus, + entrypoint_memory=self.entrypoint_memory, + entrypoint_resources=self.entrypoint_resources, + ) + self.log.info("Submitted Ray Job id=%s", submitted_job_id) + RayJobLink.persist( + context=context, + cluster_address=self.cluster_address, + job_id=submitted_job_id, + ) + except RuntimeError as exc: + raise exc + if self.wait_for_job_done: + self._check_job_status( + cluster_address=self.cluster_address, job_id=submitted_job_id, timeout=self.submit_job_timeout + ) + + if self.get_job_logs: + self._get_job_logs(cluster_address=self.cluster_address, job_id=submitted_job_id) + + return submitted_job_id + + +class RayStopJobOperator(RayJobBaseOperator): + """ + Stop Job on Ray cluster. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :param job_id: Required. 
The job ID or submission ID for the job to be stopped. + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. + If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple( + {"cluster_address", "job_id"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = () + + def __init__( + self, + cluster_address: str, + job_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.job_id = job_id + + def execute(self, context: Context): + self.log.info("Stopping Job %s on a Ray cluster...", self.job_id) + + try: + is_stopped = self.hook.stop_job( + cluster_address=self.cluster_address, + job_id=self.job_id, + ) + except NotFound: + raise AirflowNotFoundException("Job with specified id was not found on cluster.") + + if is_stopped: + self.log.info("Job was successfully stopped.") + return + + raise OperationFailedException("Some error happened during stopping the Job. Exiting.") + + +class RayDeleteJobOperator(RayJobBaseOperator): + """ + Delete Job on Ray cluster in a terminal state and all of its associated data. + + If the job is not already in a terminal state, raises an error. + This does not delete the job logs from disk. + Submitting a job with the same submission ID as a previously + deleted job is not supported and may lead to unexpected behavior. + + :param cluster_address: Required. 
Either (1) the address of the Ray cluster, or (2) the HTTP address +        of the dashboard server on the head node, e.g. "http://:8265". +        In case (1) it must be specified as an address that can be passed to +        ray.init(), e.g. a Ray Client address (ray://:10001), +        or "auto", or "localhost:". +    :param job_id: Required. The job ID or submission ID for the job to be deleted. +    :param gcp_conn_id: The connection ID to use connecting to Google Cloud. +    :param impersonation_chain: Optional service account to impersonate using short-term +        credentials, or chained list of accounts required to get the access_token +        of the last account in the list, which will be impersonated in the request. +        If set as a string, the account must grant the originating account +        the Service Account Token Creator IAM role. +        If set as a sequence, the identities from the list must grant +        Service Account Token Creator IAM role to the directly preceding identity, with first +        account from the list granting this role to the originating account (templated). +    """ + +    template_fields: Sequence[str] = tuple( +        {"cluster_address", "job_id"} | set(RayJobBaseOperator.template_fields) +    ) +    operator_extra_links = () + +    def __init__( +        self, +        cluster_address: str, +        job_id: str, +        *args, +        **kwargs, +    ) -> None: +        super().__init__(*args, **kwargs) +        self.cluster_address = cluster_address +        self.job_id = job_id + +    def execute(self, context: Context): +        self.log.info("Deleting Job %s on a Ray cluster...", self.job_id) + +        try: +            is_deleted = self.hook.delete_job( +                cluster_address=self.cluster_address, +                job_id=self.job_id, +            ) +        except NotFound: +            raise AirflowNotFoundException("Job with specified id was not found on cluster.") + +        if is_deleted: +            self.log.info("Job was successfully deleted.") +            return + +        raise OperationFailedException("Some error happened during deleting the Job. 
Exiting.") + + +class RayGetJobInfoOperator(RayJobBaseOperator): +    """ +    Get the latest status and other information associated with a Job on Ray cluster. + +    :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address +        of the dashboard server on the head node, e.g. "http://:8265". +        In case (1) it must be specified as an address that can be passed to +        ray.init(), e.g. a Ray Client address (ray://:10001), +        or "auto", or "localhost:". +    :param job_id: Required. The job ID or submission ID for the job to be retrieved. +    :param gcp_conn_id: The connection ID to use connecting to Google Cloud. +    :param impersonation_chain: Optional service account to impersonate using short-term +        credentials, or chained list of accounts required to get the access_token +        of the last account in the list, which will be impersonated in the request. +        If set as a string, the account must grant the originating account +        the Service Account Token Creator IAM role. +        If set as a sequence, the identities from the list must grant +        Service Account Token Creator IAM role to the directly preceding identity, with first +        account from the list granting this role to the originating account (templated). 
+ """ + + template_fields: Sequence[str] = tuple( + {"cluster_address", "job_id"} | set(RayJobBaseOperator.template_fields) + ) + operator_extra_links = () + + def __init__( + self, + cluster_address: str, + job_id: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + self.job_id = job_id + + def execute(self, context: Context): + self.log.info("Retrieving information about Job %s on a Ray cluster...", self.job_id) + + try: + job_info = self.hook.get_job_info( + cluster_address=self.cluster_address, + job_id=self.job_id, + ) + except NotFound: + raise AirflowNotFoundException("Job with specified id was not found on cluster.") + + # Everything below is outside the try/except so serialization/logging errors are not masked + self.log.info("Job information:\n %s \n", job_info) + + ray_job_dict = self.hook.serialize_job_obj(job_info) + return ray_job_dict + + +class RayListJobsOperator(RayJobBaseOperator): + """ + List all jobs along with their status and other information. + + Lists all jobs that have ever run on the cluster, including jobs that are + currently running and jobs that are no longer running. + + :param cluster_address: Required. Either (1) the address of the Ray cluster, or (2) the HTTP address + of the dashboard server on the head node, e.g. "http://:8265". + In case (1) it must be specified as an address that can be passed to + ray.init(), e.g. a Ray Client address (ray://:10001), + or "auto", or "localhost:". + :param gcp_conn_id: The connection ID to use connecting to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+ If set as a sequence, the identities from the list must grant + Service Account Token Creator IAM role to the directly preceding identity, with first + account from the list granting this role to the originating account (templated). + """ + + template_fields: Sequence[str] = tuple({"cluster_address"} | set(RayJobBaseOperator.template_fields)) + operator_extra_links = () + + def __init__( + self, + cluster_address: str, + *args, + **kwargs, + ) -> None: + super().__init__(*args, **kwargs) + self.cluster_address = cluster_address + + def execute(self, context: Context): + self.log.info("Listing Jobs on a Ray cluster...") + jobs = self.hook.list_jobs( + cluster_address=self.cluster_address, + ) + if jobs: + self.log.info("Outputting first 10 Jobs...") + for job in jobs[:10]: + self.log.info("Job:\n %s \n \n", job) + list_jobs = [self.hook.serialize_job_obj(job) for job in jobs] + return list_jobs + self.log.info("No Jobs found.") + return diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py index 5296df9d23f6e..6c0a0a76d166b 100644 --- a/providers/google/src/airflow/providers/google/get_provider_info.py +++ b/providers/google/src/airflow/providers/google/get_provider_info.py @@ -474,6 +474,12 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-google/operators/cloud/gen_ai.rst"], "tags": ["gcp"], }, + { + "integration-name": "Google Ray", + "external-doc-url": "https://docs.cloud.google.com/vertex-ai/docs/open-source/ray-on-vertex-ai/overview", + "how-to-guide": ["/docs/apache-airflow-providers-google/operators/cloud/ray.rst"], + "tags": ["gcp"], + }, ], "operators": [ { @@ -691,6 +697,10 @@ def get_provider_info(): "integration-name": "Google Cloud Generative AI", "python-modules": ["airflow.providers.google.cloud.operators.gen_ai"], }, + { + "integration-name": "Google Ray", + "python-modules": ["airflow.providers.google.cloud.operators.ray"], + 
}, ], "sensors": [ { @@ -1055,6 +1065,10 @@ def get_provider_info(): "integration-name": "Google Cloud Generative AI", "python-modules": ["airflow.providers.google.cloud.hooks.gen_ai"], }, + { + "integration-name": "Google Ray", + "python-modules": ["airflow.providers.google.cloud.hooks.ray"], + }, ], "bundles": [ { @@ -1517,6 +1531,7 @@ def get_provider_info(): "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink", "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink", "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink", + "airflow.providers.google.cloud.links.ray.RayJobLink", ], "secrets-backends": [ "airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend" diff --git a/providers/google/tests/system/google/cloud/ray/__init__.py b/providers/google/tests/system/google/cloud/ray/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/tests/system/google/cloud/ray/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/providers/google/tests/system/google/cloud/ray/example_ray_job.py b/providers/google/tests/system/google/cloud/ray/example_ray_job.py new file mode 100644 index 0000000000000..be3aa637b9599 --- /dev/null +++ b/providers/google/tests/system/google/cloud/ray/example_ray_job.py @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +""" +Example Airflow DAG for Jobs on Ray operations. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow.exceptions import AirflowOptionalProviderFeatureException +from airflow.models.dag import DAG +from airflow.providers.google.cloud.operators.ray import ( + RayDeleteJobOperator, + RayGetJobInfoOperator, + RayListJobsOperator, + RayStopJobOperator, + RaySubmitJobOperator, +) +from airflow.providers.google.cloud.operators.vertex_ai.ray import ( + CreateRayClusterOperator, + DeleteRayClusterOperator, + GetRayClusterOperator, +) + +try: + from google.cloud.aiplatform.vertex_ray.util import resources +except ImportError: + raise AirflowOptionalProviderFeatureException( + "The ray provider is optional and requires the `google-cloud-aiplatform` package to be installed. 
" + ) +try: + from airflow.sdk import TriggerRule +except ImportError: + # Compatibility for Airflow < 3.1 + from airflow.utils.trigger_rule import TriggerRule # type: ignore[no-redef,attr-defined] + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default") +DAG_ID = "ray_job_operations" +LOCATION = "us-central1" +JOB_ID = f"{DAG_ID}_{ENV_ID}".replace("-", "_") +WORKER_NODE_RESOURCES = resources.Resources( + node_count=1, +) + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + catchup=False, + render_template_as_native_obj=True, + tags=["example", "job", "ray"], +) as dag: + create_ray_cluster = CreateRayClusterOperator( + task_id="create_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + worker_node_types=[WORKER_NODE_RESOURCES], + python_version="3.10", + ray_version="2.33", + ) + + get_ray_cluster = GetRayClusterOperator( + task_id="get_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=create_ray_cluster.output["cluster_id"], + ) + + # [START how_to_ray_submit_job] + submit_ray_job = RaySubmitJobOperator( + task_id="submit_ray_job", + cluster_address="{{ task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + entrypoint="python3 heavy.py", + runtime_env={ + "working_dir": "./providers/google/tests/system/google/cloud/ray/resources", + "pip": [ + "ray==2.33.0", + ], + }, + get_job_logs=False, + wait_for_job_done=False, + submission_id=JOB_ID, + ) + # [END how_to_ray_submit_job] + + # [START how_to_ray_get_job_info] + info_ray_job = RayGetJobInfoOperator( + task_id="info_ray_job", + cluster_address="{{ task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + job_id=JOB_ID, + ) + # [END how_to_ray_get_job_info] + + # [START how_to_ray_list_jobs] + list_ray_job = RayListJobsOperator( + task_id="list_ray_job", + cluster_address="{{ 
task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + ) + # [END how_to_ray_list_jobs] + + # [START how_to_ray_stop_job] + stop_ray_job = RayStopJobOperator( + task_id="stop_ray_job", + job_id=JOB_ID, + cluster_address="{{ task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + ) + # [END how_to_ray_stop_job] + + # [START how_to_ray_delete_job] + delete_ray_job = RayDeleteJobOperator( + task_id="delete_ray_job", + cluster_address="{{ task_instance.xcom_pull(task_ids='get_ray_cluster')['dashboard_address'] }}", + job_id=JOB_ID, + trigger_rule=TriggerRule.ALL_DONE, + ) + # [END how_to_ray_delete_job] + + delete_ray_cluster = DeleteRayClusterOperator( + task_id="delete_ray_cluster", + project_id=PROJECT_ID, + location=LOCATION, + cluster_id=create_ray_cluster.output["cluster_id"], + trigger_rule=TriggerRule.ALL_DONE, + ) + + ( + create_ray_cluster + >> get_ray_cluster + >> submit_ray_job + >> info_ray_job + >> stop_ray_job + >> list_ray_job + >> delete_ray_job + >> delete_ray_cluster + ) + + # ### Everything below this line is not part of example ### + # ### Just for system tests purpose ### + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/google/tests/system/google/cloud/ray/resources/__init__.py b/providers/google/tests/system/google/cloud/ray/resources/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/google/tests/system/google/cloud/ray/resources/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/google/tests/system/google/cloud/ray/resources/heavy.py b/providers/google/tests/system/google/cloud/ray/resources/heavy.py new file mode 100644 index 0000000000000..fb25b725a588d --- /dev/null +++ b/providers/google/tests/system/google/cloud/ray/resources/heavy.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import time + +import ray + +# Initialize Ray +ray.init() + + +# Define a computationally intensive task +@ray.remote(num_cpus=1) +def heavy_task(x): + """ + Simulates a heavy workload by performing a CPU-bound operation. + This example calculates the sum of squares for a range of numbers. + """ + total = 0 + for i in range(x): + total += i * i + time.sleep(1) # Simulate some work duration + return total + + +# Generate a large number of tasks +num_tasks = 1000 +results = [] +for _i in range(num_tasks): + results.append(heavy_task.remote(1000000)) + +# Retrieve results (this will trigger autoscaling if needed) +outputs = ray.get(results) + +# Print the sum of the results (optional) +print(f"Sum of results: {sum(outputs)}") + +# Terminate the process +ray.shutdown() diff --git a/providers/google/tests/unit/google/cloud/hooks/test_ray.py b/providers/google/tests/unit/google/cloud/hooks/test_ray.py new file mode 100644 index 0000000000000..e22c07204ac1f --- /dev/null +++ b/providers/google/tests/unit/google/cloud/hooks/test_ray.py @@ -0,0 +1,160 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from airflow.providers.google.cloud.hooks.ray import RayJobHook + +from unit.google.cloud.utils.base_gcp_mock import ( + mock_base_gcp_hook_default_project_id, +) + +TEST_GCP_CONN_ID: str = "test-gcp-conn-id" +TEST_CLUSTER_NAME: str = "test-cluster-name" + +BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}" +RAY_JOB_STRING = "airflow.providers.google.cloud.hooks.ray.{}" + +TEST_CLUSTER_HOSTNAME = "ray.aiplatform-training.googleusercontent.com" +TEST_JOB_ID = "test-job-id" + + +class TestRayJobHook: + def setup_method(self): + with mock.patch( + BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id + ): + self.hook = RayJobHook(gcp_conn_id=TEST_GCP_CONN_ID) + self.hook.get_credentials = mock.MagicMock() + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_submit_job(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + mock_client.submit_job.return_value = TEST_JOB_ID + job_id = self.hook.submit_job( + entrypoint="python3 heavy.py", + cluster_address=TEST_CLUSTER_HOSTNAME, + runtime_env={"a": "b"}, + metadata={"k": "v"}, + submission_id="sub-123", + entrypoint_num_cpus=1, + entrypoint_num_gpus=0, + entrypoint_memory=1024, + entrypoint_resources={"CPU": 1.0}, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.submit_job.assert_called_once_with( + entrypoint="python3 heavy.py", + runtime_env={"a": "b"}, + metadata={"k": "v"}, + submission_id="sub-123", + entrypoint_num_cpus=1, + entrypoint_num_gpus=0, + entrypoint_memory=1024, + entrypoint_resources={"CPU": 1.0}, + ) + assert job_id == TEST_JOB_ID + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_stop_job(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + mock_client.stop_job.return_value = True + + result = self.hook.stop_job( + job_id=TEST_JOB_ID, + 
cluster_address=TEST_CLUSTER_HOSTNAME, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.stop_job.assert_called_once_with(job_id=TEST_JOB_ID) + assert result is True + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_delete_job(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + mock_client.delete_job.return_value = True + + result = self.hook.delete_job( + job_id=TEST_JOB_ID, + cluster_address=TEST_CLUSTER_HOSTNAME, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.delete_job.assert_called_once_with(job_id=TEST_JOB_ID) + assert result is True + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_get_job_info(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + fake_job_details = object() + mock_client.get_job_info.return_value = fake_job_details + + result = self.hook.get_job_info( + job_id=TEST_JOB_ID, + cluster_address=TEST_CLUSTER_HOSTNAME, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.get_job_info.assert_called_once_with(job_id=TEST_JOB_ID) + assert result is fake_job_details + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_list_jobs(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + fake_list = [object(), object()] + mock_client.list_jobs.return_value = fake_list + + result = self.hook.list_jobs( + cluster_address=TEST_CLUSTER_HOSTNAME, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.list_jobs.assert_called_once_with() + assert result is fake_list + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_get_job_status(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + fake_status = object() + mock_client.get_job_status.return_value = fake_status + + result = 
self.hook.get_job_status( + job_id=TEST_JOB_ID, + cluster_address=TEST_CLUSTER_HOSTNAME, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.get_job_status.assert_called_once_with(job_id=TEST_JOB_ID) + assert result is fake_status + + @mock.patch(RAY_JOB_STRING.format("JobSubmissionClient")) + def test_get_job_logs(self, mock_client_cls) -> None: + mock_client = mock_client_cls.return_value + fake_logs = "some logs" + mock_client.get_job_logs.return_value = fake_logs + + result = self.hook.get_job_logs( + job_id=TEST_JOB_ID, + cluster_address=TEST_CLUSTER_HOSTNAME, + ) + + mock_client_cls.assert_called_once_with(f"vertex_ray://{TEST_CLUSTER_HOSTNAME}") + mock_client.get_job_logs.assert_called_once_with(job_id=TEST_JOB_ID) + assert result == fake_logs diff --git a/providers/google/tests/unit/google/cloud/links/test_ray.py b/providers/google/tests/unit/google/cloud/links/test_ray.py new file mode 100644 index 0000000000000..e7b51b9932231 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/links/test_ray.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest import mock + +from airflow.providers.google.cloud.links.ray import RayJobLink + +TEST_CLUSTER_ADDRESS = "ray-head-123.us-central1.ray.googleusercontent.com" +TEST_JOB_ID = "test-job-id" + +EXPECTED_RAY_JOB_LINK_NAME = "Ray Job" +EXPECTED_RAY_JOB_LINK_KEY = "ray_job" +EXPECTED_RAY_JOB_LINK_FORMAT_STR = "http://{cluster_address}/#/jobs/{job_id}" + + +class TestRayJobLink: + def test_class_attributes(self): + assert RayJobLink.key == EXPECTED_RAY_JOB_LINK_KEY + assert RayJobLink.name == EXPECTED_RAY_JOB_LINK_NAME + assert RayJobLink.format_str == EXPECTED_RAY_JOB_LINK_FORMAT_STR + + def test_persist(self): + mock_context = mock.MagicMock() + mock_context["ti"] = mock.MagicMock() + mock_context["task"] = mock.MagicMock() + + RayJobLink.persist( + context=mock_context, + cluster_address=TEST_CLUSTER_ADDRESS, + job_id=TEST_JOB_ID, + ) + + mock_context["ti"].xcom_push.assert_called_once_with( + key=EXPECTED_RAY_JOB_LINK_KEY, + value={ + "cluster_address": TEST_CLUSTER_ADDRESS, + "job_id": TEST_JOB_ID, + }, + ) diff --git a/providers/google/tests/unit/google/cloud/operators/test_ray.py b/providers/google/tests/unit/google/cloud/operators/test_ray.py new file mode 100644 index 0000000000000..3256925803a10 --- /dev/null +++ b/providers/google/tests/unit/google/cloud/operators/test_ray.py @@ -0,0 +1,334 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest +from ray.dashboard.modules.job.common import JobStatus + +from airflow.providers.common.compat.sdk import AirflowTaskTimeout +from airflow.providers.google.cloud.operators.ray import ( + RayDeleteJobOperator, + RayGetJobInfoOperator, + RayListJobsOperator, + RayStopJobOperator, + RaySubmitJobOperator, +) + +TASK_ID = "test-task" +GCP_CONN_ID = "test-gcp-conn-id" +IMPERSONATION_CHAIN = "test-impersonation" +CLUSTER_ADDRESS = "ray-head-123.us-central1.ray.googleusercontent.com" +ENTRYPOINT = "python3 heavy.py" +SUBMISSION_ID = "submission-123" +JOB_ID = "job-123" + +RAY_OP_PATH = "airflow.providers.google.cloud.operators.ray.{}" + + +class TestRaySubmitJobOperator: + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_execute_submits_job_and_persists_link(self, mock_hook_cls): + mock_hook = mock_hook_cls.return_value + mock_hook.submit_job.return_value = JOB_ID + + op = RaySubmitJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + wait_for_job_done=False, + get_job_logs=False, + runtime_env={"k": "v"}, + metadata={"m": "v"}, + submission_id=SUBMISSION_ID, + entrypoint_num_cpus=1, + entrypoint_num_gpus=0, + entrypoint_memory=1024, + entrypoint_resources={"CPU": 1.0}, + ) + + ti_mock = mock.MagicMock() + context = {"ti": ti_mock, "task": mock.MagicMock()} + + op.execute(context=context) + + mock_hook_cls.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + 
impersonation_chain=IMPERSONATION_CHAIN, + ) + + mock_hook.submit_job.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + runtime_env={"k": "v"}, + metadata={"m": "v"}, + submission_id=SUBMISSION_ID, + entrypoint_num_cpus=1, + entrypoint_num_gpus=0, + entrypoint_memory=1024, + entrypoint_resources={"CPU": 1.0}, + ) + ti_mock.xcom_push.assert_called_once_with( + key="ray_job", + value={ + "cluster_address": CLUSTER_ADDRESS, + "job_id": JOB_ID, + }, + ) + mock_hook.get_job_status.assert_not_called() + mock_hook.get_job_logs.assert_not_called() + + def test_execute_raises_if_logs_without_wait(self): + op = RaySubmitJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + wait_for_job_done=False, + get_job_logs=True, + ) + + ti_mock = mock.MagicMock() + context = {"ti": ti_mock, "task": mock.MagicMock()} + + with pytest.raises(ValueError, match="Retrieving Job logs can be possible only after Job completion"): + op.execute(context=context) + + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_execute_waits_for_job_and_gets_logs(self, mock_hook_cls): + mock_hook = mock_hook_cls.return_value + mock_hook.submit_job.return_value = JOB_ID + mock_hook.get_job_status.return_value = JobStatus.SUCCEEDED + mock_hook.get_job_logs.return_value = "some logs" + + op = RaySubmitJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + wait_for_job_done=True, + get_job_logs=True, + ) + + ti_mock = mock.MagicMock() + context = {"ti": ti_mock, "task": mock.MagicMock()} + + op.execute(context=context) + mock_hook.submit_job.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + runtime_env=None, + metadata=None, + submission_id=None, + entrypoint_num_cpus=None, + entrypoint_num_gpus=None, + 
entrypoint_memory=None, + entrypoint_resources=None, + ) + mock_hook.get_job_status.assert_called_with( + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + + mock_hook.get_job_logs.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + + ti_mock.xcom_push.assert_called_once_with( + key="ray_job", + value={ + "cluster_address": CLUSTER_ADDRESS, + "job_id": JOB_ID, + }, + ) + + @mock.patch(RAY_OP_PATH.format("time.sleep")) + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_check_job_status_reaches_terminal(self, mock_hook_cls, mock_sleep): + mock_hook = mock_hook_cls.return_value + mock_hook.stop_job.return_value = True + + operator = RaySubmitJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + wait_for_job_done=True, + get_job_logs=True, + ) + operator.hook.get_job_status = mock.MagicMock( + side_effect=[JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.SUCCEEDED] + ) + status = operator._check_job_status("addr", "job", polling_interval=1, timeout=100) + + assert status == JobStatus.SUCCEEDED + assert mock_sleep.call_count == 2 + + @mock.patch(RAY_OP_PATH.format("time.sleep")) + @mock.patch(RAY_OP_PATH.format("time.monotonic")) + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_check_job_status_timeout(self, mock_hook_cls, mock_monotonic, mock_sleep): + mock_hook = mock_hook_cls.return_value + mock_hook.stop_job.return_value = True + mock_monotonic.side_effect = [0, 10, 20, 1000] + operator = RaySubmitJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + entrypoint=ENTRYPOINT, + wait_for_job_done=True, + get_job_logs=True, + ) + operator.hook.get_job_status = mock.MagicMock(return_value=JobStatus.RUNNING) + + with pytest.raises( + AirflowTaskTimeout, match=r"Timeout waiting for Ray Job job to finish. 
Last status: RUNNING" + ): + operator._check_job_status("addr", "job", polling_interval=1, timeout=30) + + +class TestRayStopJobOperator: + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_execute_stops_job_successfully(self, mock_hook_cls): + mock_hook = mock_hook_cls.return_value + mock_hook.stop_job.return_value = True + + op = RayStopJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + + context = {"ti": mock.MagicMock(), "task": mock.MagicMock()} + result = op.execute(context=context) + + mock_hook_cls.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.stop_job.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + assert result is None + + +class TestRayDeleteJobOperator: + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_execute_deletes_job_successfully(self, mock_hook_cls): + mock_hook = mock_hook_cls.return_value + mock_hook.delete_job.return_value = True + + op = RayDeleteJobOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + + context = {"ti": mock.MagicMock(), "task": mock.MagicMock()} + result = op.execute(context=context) + + mock_hook_cls.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.delete_job.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + assert result is None + + +class TestRayGetJobInfoOperator: + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_execute_returns_serialized_job_info(self, mock_hook_cls): + mock_hook = mock_hook_cls.return_value + fake_job_obj = object() + fake_serialized = {"job_id": JOB_ID, "status": "SUCCEEDED"} + + mock_hook.get_job_info.return_value = fake_job_obj + mock_hook.serialize_job_obj.return_value = 
fake_serialized + + op = RayGetJobInfoOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + + context = {"ti": mock.MagicMock(), "task": mock.MagicMock()} + result = op.execute(context=context) + + mock_hook_cls.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.get_job_info.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + job_id=JOB_ID, + ) + mock_hook.serialize_job_obj.assert_called_once_with(fake_job_obj) + assert result == fake_serialized + + +class TestRayListJobsOperator: + @mock.patch(RAY_OP_PATH.format("RayJobHook")) + def test_execute_lists_and_serializes_jobs(self, mock_hook_cls): + mock_hook = mock_hook_cls.return_value + job1 = object() + job2 = object() + mock_hook.list_jobs.return_value = [job1, job2] + mock_hook.serialize_job_obj.side_effect = [ + {"job_id": "job-1"}, + {"job_id": "job-2"}, + ] + + op = RayListJobsOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + cluster_address=CLUSTER_ADDRESS, + ) + + context = {"ti": mock.MagicMock(), "task": mock.MagicMock()} + result = op.execute(context=context) + + mock_hook_cls.assert_called_once_with( + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + ) + mock_hook.list_jobs.assert_called_once_with( + cluster_address=CLUSTER_ADDRESS, + ) + mock_hook.serialize_job_obj.assert_has_calls([mock.call(job1), mock.call(job2)]) + assert result == [ + {"job_id": "job-1"}, + {"job_id": "job-2"}, + ]