diff --git a/.github/workflows/add-to-project.yml b/.github/workflows/add-to-project.yml index 6133f8b0..387e9869 100644 --- a/.github/workflows/add-to-project.yml +++ b/.github/workflows/add-to-project.yml @@ -11,7 +11,7 @@ jobs: name: Add issue to project runs-on: ubuntu-latest steps: - - uses: tibdex/github-app-token@v1 + - uses: tibdex/github-app-token@v2 id: generate-token name: Generate GitHub token with: diff --git a/.github/workflows/template-sync.yml b/.github/workflows/template-sync.yml index ab537a9f..58e310c6 100644 --- a/.github/workflows/template-sync.yml +++ b/.github/workflows/template-sync.yml @@ -22,7 +22,7 @@ jobs: - name: Perform updates run: cruft update -y - - uses: tibdex/github-app-token@v1 + - uses: tibdex/github-app-token@v2 id: generate-token name: Generate GitHub token with: diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f98b1fc..cb7a939d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,13 +20,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Security +## 0.4.7 + +Released September 22nd, 2023. + +### Added + +- Vertex AI `CustomJob` worker - [#211](https://github.com/PrefectHQ/prefect-gcp/pull/211) +- Add `kill_infrastructure` method to Vertex AI worker - [#213](https://github.com/PrefectHQ/prefect-gcp/pull/213) + +### Changed + +- Use flow run name for name of created custom jobs - [#208](https://github.com/PrefectHQ/prefect-gcp/pull/208) + ## 0.4.6 -Not yet released +Released September 5th, 2023. ### Changed -- Vertex AI CustomJob sets labels specified by Prefect Agent when Deployment triggered on infrastructure. +- Persist Labels to Vertex AI Custom Job - [#198](https://github.com/PrefectHQ/prefect-gcp/pull/208) ## 0.4.5 diff --git a/docs/cloud_run_worker.md b/docs/cloud_run_worker.md new file mode 100644 index 00000000..8382c2df --- /dev/null +++ b/docs/cloud_run_worker.md @@ -0,0 +1 @@ +::: prefect_gcp.workers.cloud_run diff --git a/docs/vertex_worker.md b/docs/vertex_worker.md new file mode 100644 index 00000000..1f816a4f --- /dev/null +++ b/docs/vertex_worker.md @@ -0,0 +1 @@ +::: prefect_gcp.workers.vertex diff --git a/docs/worker.md b/docs/worker.md deleted file mode 100644 index 4bede70f..00000000 --- a/docs/worker.md +++ /dev/null @@ -1 +0,0 @@ -::: prefect_gcp.worker diff --git a/mkdocs.yml b/mkdocs.yml index f5d9670b..5f45af38 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -85,7 +85,9 @@ nav: - Cloud Run: cloud_run.md - AI Platform: aiplatform.md - Deployment Steps: deployments/steps.md - - Worker: worker.md + - Workers: + - Cloud Run: cloud_run_worker.md + - Vertex AI: vertex_worker.md extra: social: diff --git a/prefect_gcp/__init__.py b/prefect_gcp/__init__.py index 24f23177..5e616175 100644 --- a/prefect_gcp/__init__.py +++ b/prefect_gcp/__init__.py @@ -1,19 +1,23 @@ -from . import _version +from prefect._internal.compatibility.deprecated import ( + register_renamed_module, +) -from .bigquery import BigQueryWarehouse # noqa +from . import _version from .aiplatform import VertexAICustomTrainingJob # noqa -from .cloud_storage import GcsBucket # noqa +from .bigquery import BigQueryWarehouse # noqa from .cloud_run import CloudRunJob # noqa -from .secret_manager import GcpSecret # noqa +from .cloud_storage import GcsBucket # noqa from .credentials import GcpCredentials # noqa -from .worker import CloudRunWorker # noqa -from prefect._internal.compatibility.deprecated import ( - register_renamed_module, -) +from .secret_manager import GcpSecret # noqa +from .workers.vertex import VertexAIWorker # noqa +from .workers.cloud_run import CloudRunWorker # noqa register_renamed_module( "prefect_gcp.projects", "prefect_gcp.deployments", start_date="Jun 2023" ) +register_renamed_module( + "prefect_gcp.worker", "prefect_gcp.workers", start_date="Sep 2023" +) __version__ = _version.get_versions()["version"] diff --git a/prefect_gcp/aiplatform.py b/prefect_gcp/aiplatform.py index ee169dc9..130754b2 100644 --- a/prefect_gcp/aiplatform.py +++ b/prefect_gcp/aiplatform.py @@ -102,7 +102,7 @@ class VertexAICustomTrainingJob(Infrastructure): _block_type_name = "Vertex AI Custom Training Job" _block_type_slug = "vertex-ai-custom-training-job" - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4CD4wwbiIKPkZDt4U3TEuW/c112fe85653da054b6d5334ef662bec4/gcp.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa _documentation_url = "https://prefecthq.github.io/prefect-gcp/aiplatform/#prefect_gcp.aiplatform.VertexAICustomTrainingJob" # noqa: E501 type: Literal["vertex-ai-custom-training-job"] = Field( @@ -184,7 +184,6 @@ class VertexAICustomTrainingJob(Infrastructure): "and required if a service account cannot be detected in gcp_credentials." ), ) - job_watch_poll_interval: float = Field( default=5.0, description=( @@ -200,17 +199,14 @@ def job_name(self): https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.CustomJob#google_cloud_aiplatform_CustomJob_display_name """ # noqa try: - repo_name = self.image.split("/")[2] # `gcr.io///`" + base_name = self.name or self.image.split("/")[2] + return f"{base_name}-{uuid4().hex}" except IndexError: raise ValueError( "The provided image must be from either Google Container Registry " "or Google Artifact Registry" ) - unique_suffix = uuid4().hex - job_name = f"{repo_name}-{unique_suffix}" - return job_name - def _get_compatible_labels(self) -> Dict[str, str]: """ Ensures labels are compatible with GCP label requirements. @@ -430,6 +426,7 @@ async def run( raise RuntimeError(f"{self._log_prefix}: {error_msg}") status_code = 0 if final_job_run.state == JobState.JOB_STATE_SUCCEEDED else 1 + return VertexAICustomTrainingJobResult( identifier=final_job_run.display_name, status_code=status_code ) diff --git a/prefect_gcp/bigquery.py b/prefect_gcp/bigquery.py index b51c1881..52c48895 100644 --- a/prefect_gcp/bigquery.py +++ b/prefect_gcp/bigquery.py @@ -552,7 +552,7 @@ class BigQueryWarehouse(DatabaseBlock): """ # noqa _block_type_name = "BigQuery Warehouse" - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4CD4wwbiIKPkZDt4U3TEuW/c112fe85653da054b6d5334ef662bec4/gcp.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa _documentation_url = "https://prefecthq.github.io/prefect-gcp/bigquery/#prefect_gcp.bigquery.BigQueryWarehouse" # noqa: E501 gcp_credentials: GcpCredentials diff --git a/prefect_gcp/cloud_run.py b/prefect_gcp/cloud_run.py index f335b95d..ff500fb0 100644 --- a/prefect_gcp/cloud_run.py +++ b/prefect_gcp/cloud_run.py @@ -213,7 +213,7 @@ class CloudRunJob(Infrastructure): _block_type_slug = "cloud-run-job" _block_type_name = "GCP Cloud Run Job" _description = "Infrastructure block used to run GCP Cloud Run Jobs. Note this block is experimental. The interface may change without notice." # noqa - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4CD4wwbiIKPkZDt4U3TEuW/c112fe85653da054b6d5334ef662bec4/gcp.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa _documentation_url = "https://prefecthq.github.io/prefect-gcp/cloud_run/#prefect_gcp.cloud_run.CloudRunJob" # noqa: E501 type: Literal["cloud-run-job"] = Field( diff --git a/prefect_gcp/cloud_storage.py b/prefect_gcp/cloud_storage.py index 1675bc8d..06b8ef01 100644 --- a/prefect_gcp/cloud_storage.py +++ b/prefect_gcp/cloud_storage.py @@ -564,7 +564,7 @@ class GcsBucket(WritableDeploymentStorage, WritableFileSystem, ObjectStorageBloc ``` """ - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4CD4wwbiIKPkZDt4U3TEuW/c112fe85653da054b6d5334ef662bec4/gcp.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa _block_type_name = "GCS Bucket" _documentation_url = "https://prefecthq.github.io/prefect-gcp/cloud_storage/#prefect_gcp.cloud_storage.GcsBucket" # noqa: E501 diff --git a/prefect_gcp/credentials.py b/prefect_gcp/credentials.py index 3ca10f89..8900fae5 100644 --- a/prefect_gcp/credentials.py +++ b/prefect_gcp/credentials.py @@ -96,7 +96,7 @@ class GcpCredentials(CredentialsBlock): ``` """ # noqa - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4CD4wwbiIKPkZDt4U3TEuW/c112fe85653da054b6d5334ef662bec4/gcp.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa _block_type_name = "GCP Credentials" _documentation_url = "https://prefecthq.github.io/prefect-gcp/credentials/#prefect_gcp.credentials.GcpCredentials" # noqa: E501 diff --git a/prefect_gcp/secret_manager.py b/prefect_gcp/secret_manager.py index c5561fd0..1132c15a 100644 --- a/prefect_gcp/secret_manager.py +++ b/prefect_gcp/secret_manager.py @@ -300,7 +300,7 @@ class GcpSecret(SecretBlock): secret_version: Version number of the secret to use, or "latest". """ - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4CD4wwbiIKPkZDt4U3TEuW/c112fe85653da054b6d5334ef662bec4/gcp.png?h=250" # noqa + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa _documentation_url = "https://prefecthq.github.io/prefect-gcp/secret_manager/#prefect_gcp.secret_manager.GcpSecret" # noqa: E501 gcp_credentials: GcpCredentials diff --git a/prefect_gcp/workers/__init__.py b/prefect_gcp/workers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/prefect_gcp/worker.py b/prefect_gcp/workers/cloud_run.py similarity index 99% rename from prefect_gcp/worker.py rename to prefect_gcp/workers/cloud_run.py index d6aaf3ca..e3824d9b 100644 --- a/prefect_gcp/worker.py +++ b/prefect_gcp/workers/cloud_run.py @@ -519,8 +519,8 @@ class CloudRunWorker(BaseWorker): "a Google Cloud Platform account." ) _display_name = "Google Cloud Run" - _documentation_url = "https://prefecthq.github.io/prefect-gcp/worker/" - _logo_url = "https://images.ctfassets.net/gm98wzqotmnx/4SpnOBvMYkHp6z939MDKP6/549a91bc1ce9afd4fb12c68db7b68106/social-icon-google-cloud-1200-630.png?h=250" # noqa + _documentation_url = "https://prefecthq.github.io/prefect-gcp/cloud_run_worker/" + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa def _create_job_error(self, exc, configuration): """Provides a nicer error for 404s when trying to create a Cloud Run Job.""" diff --git a/prefect_gcp/workers/vertex.py b/prefect_gcp/workers/vertex.py new file mode 100644 index 00000000..cbb12e0f --- /dev/null +++ b/prefect_gcp/workers/vertex.py @@ -0,0 +1,643 @@ +""" + +Module containing the custom worker used for executing flow runs as Vertex AI Custom Jobs. + +Get started by creating a Cloud Run work pool: + +```bash +prefect work-pool create 'my-vertex-pool' --type vertex-ai +``` + +Then start a Cloud Run worker with the following command: + +```bash +prefect worker start --pool 'my-vertex-pool' +``` + +## Configuration +Read more about configuring work pools +[here](https://docs.prefect.io/latest/concepts/work-pools/#work-pool-overview). +""" +import datetime +import re +import shlex +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from uuid import uuid4 + +import anyio +from prefect.exceptions import InfrastructureNotFound +from prefect.logging.loggers import PrefectLogAdapter +from prefect.utilities.asyncutils import run_sync_in_worker_thread +from prefect.utilities.pydantic import JsonPatch +from prefect.workers.base import ( + BaseJobConfiguration, + BaseVariables, + BaseWorker, + BaseWorkerResult, +) +from pydantic import Field, validator +from slugify import slugify + +from prefect_gcp.credentials import GcpCredentials + +# to prevent "Failed to load collection" from surfacing +# if google-cloud-aiplatform is not installed +try: + from google.api_core.client_options import ClientOptions + from google.cloud.aiplatform.gapic import JobServiceClient + from google.cloud.aiplatform_v1.types.custom_job import ( + ContainerSpec, + CustomJob, + CustomJobSpec, + Scheduling, + WorkerPoolSpec, + ) + from google.cloud.aiplatform_v1.types.job_service import CancelCustomJobRequest + from google.cloud.aiplatform_v1.types.job_state import JobState + from google.cloud.aiplatform_v1.types.machine_resources import DiskSpec, MachineSpec + from google.protobuf.duration_pb2 import Duration + from tenacity import retry, stop_after_attempt, wait_fixed, wait_random +except ModuleNotFoundError: + pass + +_DISALLOWED_GCP_LABEL_CHARACTERS = re.compile(r"[^-a-zA-Z0-9_]+") + +if TYPE_CHECKING: + from prefect.client.schemas import FlowRun + from prefect.server.schemas.core import Flow + from prefect.server.schemas.responses import DeploymentResponse + + +class VertexAIWorkerVariables(BaseVariables): + """ + Default variables for the Vertex AI worker. + + The schema for this class is used to populate the `variables` section of the default + base job template. + """ + + region: str = Field( + description="The region where the Vertex AI Job resides.", + example="us-central1", + ) + image: str = Field( + title="Image Name", + description=( + "The URI of a container image in the Container or Artifact Registry, " + "used to run your Vertex AI Job. Note that Vertex AI will need access" + "to the project and region where the container image is stored. See " + "https://cloud.google.com/vertex-ai/docs/training/create-custom-container" + ), + example="gcr.io/your-project/your-repo:latest", + ) + credentials: Optional[GcpCredentials] = Field( + title="GCP Credentials", + default_factory=GcpCredentials, + description="The GCP Credentials used to initiate the " + "Vertex AI Job. If not provided credentials will be " + "inferred from the local environment.", + ) + machine_type: str = Field( + title="Machine Type", + description=( + "The machine type to use for the run, which controls " + "the available CPU and memory. " + "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec" + ), + default="n1-standard-4", + ) + accelerator_type: Optional[str] = Field( + title="Accelerator Type", + description=( + "The type of accelerator to attach to the machine. " + "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec" + ), + example="NVIDIA_TESLA_K80", + default=None, + ) + accelerator_count: Optional[int] = Field( + title="Accelerator Count", + description=( + "The number of accelerators to attach to the machine. " + "See https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec" + ), + example=1, + default=None, + ) + boot_disk_type: str = Field( + title="Boot Disk Type", + description="The type of boot disk to attach to the machine.", + default="pd-ssd", + ) + boot_disk_size_gb: int = Field( + title="Boot Disk Size (GB)", + description="The size of the boot disk to attach to the machine, in gigabytes.", + default=100, + ) + maximum_run_time_hours: int = Field( + default=1, + title="Maximum Run Time (Hours)", + description="The maximum job running time, in hours", + ) + network: Optional[str] = Field( + default=None, + title="Network", + description="The full name of the Compute Engine network" + "to which the Job should be peered. Private services access must " + "already be configured for the network. If left unspecified, the job " + "is not peered with any network. " + "For example: projects/12345/global/networks/myVPC", + ) + reserved_ip_ranges: Optional[List[str]] = Field( + default=None, + title="Reserved IP Ranges", + description="A list of names for the reserved ip ranges under the VPC " + "network that can be used for this job. If set, we will deploy the job " + "within the provided ip ranges. Otherwise, the job will be deployed to " + "any ip ranges under the provided VPC network.", + ) + service_account_name: Optional[str] = Field( + default=None, + title="Service Account Name", + description=( + "Specifies the service account to use " + "as the run-as account in Vertex AI. The worker submitting jobs must have " + "act-as permission on this run-as account. If unspecified, the AI " + "Platform Custom Code Service Agent for the CustomJob's project is " + "used. Takes precedence over the service account found in GCP credentials, " + "and required if a service account cannot be detected in GCP credentials." + ), + ) + job_watch_poll_interval: float = Field( + default=5.0, + title="Poll Interval (Seconds)", + description=( + "The amount of time to wait between GCP API calls while monitoring the " + "state of a Vertex AI Job." + ), + ) + + +def _get_base_job_spec() -> Dict[str, Any]: + """Returns a base job body to use for job spec validation. + Note that the values are stubbed and are not used for the actual job.""" + return { + "maximum_run_time_hours": "1", + "worker_pool_specs": [ + { + "replica_count": 1, + "container_spec": { + "image_uri": "gcr.io/your-project/your-repo:latest", + }, + "machine_spec": { + "machine_type": "n1-standard-4", + }, + "disk_spec": { + "boot_disk_type": "pd-ssd", + "boot_disk_size_gb": "100", + }, + } + ], + } + + +class VertexAIWorkerJobConfiguration(BaseJobConfiguration): + """ + Configuration class used by the Vertex AI Worker to create a Job. + + An instance of this class is passed to the Vertex AI Worker's `run` method + for each flow run. It contains all information necessary to execute + the flow run as a Vertex AI Job. + + Attributes: + region: The region where the Vertex AI Job resides. + credentials: The GCP Credentials used to connect to Vertex AI. + job_spec: The Vertex AI Job spec used to create the Job. + job_watch_poll_interval: The interval between GCP API calls to check Job state. + """ + + region: str = Field( + description="The region where the Vertex AI Job resides.", + example="us-central1", + ) + credentials: Optional[GcpCredentials] = Field( + title="GCP Credentials", + default_factory=GcpCredentials, + description="The GCP Credentials used to initiate the " + "Vertex AI Job. If not provided credentials will be " + "inferred from the local environment.", + ) + + job_spec: Dict[str, Any] = Field( + template={ + "service_account_name": "{{ service_account_name }}", + "network": "{{ network }}", + "reserved_ip_ranges": "{{ reserved_ip_ranges }}", + "maximum_run_time_hours": "{{ maximum_run_time_hours }}", + "worker_pool_specs": [ + { + "replica_count": 1, + "container_spec": { + "image_uri": "{{ image }}", + "command": "{{ command }}", + "args": [], + }, + "machine_spec": { + "machine_type": "{{ machine_type }}", + "accelerator_type": "{{ accelerator_type }}", + "accelerator_count": "{{ accelerator_count }}", + }, + "disk_spec": { + "boot_disk_type": "{{ boot_disk_type }}", + "boot_disk_size_gb": "{{ boot_disk_size_gb }}", + }, + } + ], + } + ) + job_watch_poll_interval: float = Field( + default=5.0, + title="Poll Interval (Seconds)", + description=( + "The amount of time to wait between GCP API calls while monitoring the " + "state of a Vertex AI Job." + ), + ) + + @property + def project(self) -> str: + """property for accessing the project from the credentials.""" + return self.credentials.project + + @property + def job_name(self) -> str: + """ + The name can be up to 128 characters long and can be consist of any UTF-8 characters. Reference: + https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.CustomJob#google_cloud_aiplatform_CustomJob_display_name + """ # noqa + unique_suffix = uuid4().hex + job_name = f"{self.name}-{unique_suffix}" + return job_name + + def prepare_for_flow_run( + self, + flow_run: "FlowRun", + deployment: Optional["DeploymentResponse"] = None, + flow: Optional["Flow"] = None, + ): + super().prepare_for_flow_run(flow_run, deployment, flow) + + self._inject_formatted_env_vars() + self._inject_formatted_command() + self._ensure_existence_of_service_account() + + def _inject_formatted_env_vars(self): + """Inject environment variables in the Vertex job_spec configuration, + in the correct format, which is sourced from the BaseJobConfiguration. + This method is invoked by `prepare_for_flow_run()`.""" + worker_pool_specs = self.job_spec["worker_pool_specs"] + formatted_env_vars = [ + {"name": key, "value": value} for key, value in self.env.items() + ] + worker_pool_specs[0]["container_spec"]["env"] = formatted_env_vars + + def _inject_formatted_command(self): + """Inject shell commands in the Vertex job_spec configuration, + in the correct format, which is sourced from the BaseJobConfiguration. + Here, we'll ensure that the default string format + is converted to a list of strings.""" + worker_pool_specs = self.job_spec["worker_pool_specs"] + + existing_command = worker_pool_specs[0]["container_spec"].get("command") + if existing_command is None: + worker_pool_specs[0]["container_spec"]["command"] = [ + "python", + "-m", + "prefect.engine", + ] + elif isinstance(existing_command, str): + worker_pool_specs[0]["container_spec"]["command"] = shlex.split( + existing_command + ) + + def _ensure_existence_of_service_account(self): + """Verify that a service account was provided, either in the credentials + or as a standalone service account name override.""" + + provided_service_account_name = self.job_spec.get("service_account_name") + credential_service_account = self.credentials._service_account_email + + service_account_to_use = ( + provided_service_account_name or credential_service_account + ) + + if service_account_to_use is None: + raise ValueError( + "A service account is required for the Vertex job. " + "A service account could not be detected in the attached credentials " + "or in the service_account_name input. " + "Please pass in valid GCP credentials or a valid service_account_name" + ) + + self.job_spec["service_account_name"] = service_account_to_use + + @validator("job_spec") + def _ensure_job_spec_includes_required_attributes(cls, value: Dict[str, Any]): + """ + Ensures that the job spec includes all required components. + """ + patch = JsonPatch.from_diff(value, _get_base_job_spec()) + missing_paths = sorted([op["path"] for op in patch if op["op"] == "add"]) + if missing_paths: + raise ValueError( + "Job is missing required attributes at the following paths: " + f"{', '.join(missing_paths)}" + ) + return value + + +class VertexAIWorkerResult(BaseWorkerResult): + """Contains information about the final state of a completed process""" + + +class VertexAIWorker(BaseWorker): + """Prefect worker that executes flow runs within Vertex AI Jobs.""" + + type = "vertex-ai" + job_configuration = VertexAIWorkerJobConfiguration + job_configuration_variables = VertexAIWorkerVariables + _description = ( + "Execute flow runs within containers on Google Vertex AI. Requires " + "a Google Cloud Platform account." + ) + _display_name = "Google Vertex AI" + _documentation_url = "https://prefecthq.github.io/prefect-gcp/vertex_worker/" + _logo_url = "https://cdn.sanity.io/images/3ugk85nk/production/10424e311932e31c477ac2b9ef3d53cefbaad708-250x250.png" # noqa + + async def run( + self, + flow_run: "FlowRun", + configuration: VertexAIWorkerJobConfiguration, + task_status: Optional[anyio.abc.TaskStatus] = None, + ) -> VertexAIWorkerResult: + """ + Executes a flow run within a Vertex AI Job and waits for the flow run + to complete. + + Args: + flow_run: The flow run to execute + configuration: The configuration to use when executing the flow run. + task_status: The task status object for the current flow run. If provided, + the task will be marked as started. + + Returns: + VertexAIWorkerResult: A result object containing information about the + final state of the flow run + """ + logger = self.get_flow_run_logger(flow_run) + + client_options = ClientOptions( + api_endpoint=f"{configuration.region}-aiplatform.googleapis.com" + ) + + job_name = configuration.job_name + + job_spec = self._build_job_spec(configuration) + with configuration.credentials.get_job_service_client( + client_options=client_options + ) as job_service_client: + job_run = await self._create_and_begin_job( + job_name, job_spec, job_service_client, configuration, logger + ) + + if task_status: + task_status.started(job_run.name) + + final_job_run = await self._watch_job_run( + job_name=job_name, + full_job_name=job_run.name, + job_service_client=job_service_client, + current_state=job_run.state, + until_states=( + JobState.JOB_STATE_SUCCEEDED, + JobState.JOB_STATE_FAILED, + JobState.JOB_STATE_CANCELLED, + JobState.JOB_STATE_EXPIRED, + ), + configuration=configuration, + logger=logger, + timeout=int( + datetime.timedelta( + hours=configuration.job_spec["maximum_run_time_hours"] + ).total_seconds() + ), + ) + + error_msg = final_job_run.error.message + + # Vertex will include an error message upon valid + # flow cancellations, so we'll avoid raising an error in that case + if error_msg and "CANCELED" not in error_msg: + raise RuntimeError(error_msg) + + status_code = 0 if final_job_run.state == JobState.JOB_STATE_SUCCEEDED else 1 + + return VertexAIWorkerResult( + identifier=final_job_run.display_name, status_code=status_code + ) + + def _build_job_spec( + self, configuration: VertexAIWorkerJobConfiguration + ) -> "CustomJobSpec": + """ + Builds a job spec by gathering details. + """ + # here, we extract the `worker_pool_specs` out of the job_spec + worker_pool_specs = [ + WorkerPoolSpec( + container_spec=ContainerSpec(**spec["container_spec"]), + machine_spec=MachineSpec(**spec["machine_spec"]), + replica_count=spec["replica_count"], + disk_spec=DiskSpec(**spec["disk_spec"]), + ) + for spec in configuration.job_spec.pop("worker_pool_specs", []) + ] + + timeout = Duration().FromTimedelta( + td=datetime.timedelta( + hours=configuration.job_spec["maximum_run_time_hours"] + ) + ) + scheduling = Scheduling(timeout=timeout) + + # construct the final job spec that we will provide to Vertex AI + job_spec = CustomJobSpec( + worker_pool_specs=worker_pool_specs, + scheduling=scheduling, + ignore_unknown_fields=True, + **configuration.job_spec, + ) + return job_spec + + async def _create_and_begin_job( + self, + job_name: str, + job_spec: "CustomJobSpec", + job_service_client: "JobServiceClient", + configuration: VertexAIWorkerJobConfiguration, + logger: PrefectLogAdapter, + ) -> "CustomJob": + """ + Builds a custom job and begins running it. + """ + # create custom job + custom_job = CustomJob( + display_name=job_name, + job_spec=job_spec, + labels=self._get_compatible_labels(configuration=configuration), + ) + + # run job + logger.info(f"Job {job_name!r} starting to run ") + + project = configuration.project + resource_name = f"projects/{project}/locations/{configuration.region}" + + retry_policy = retry( + stop=stop_after_attempt(3), wait=wait_fixed(1) + wait_random(0, 3) + ) + + custom_job_run = await run_sync_in_worker_thread( + retry_policy(job_service_client.create_custom_job), + parent=resource_name, + custom_job=custom_job, + ) + + logger.info( + f"Job {job_name!r} has successfully started; " + f"the full job name is {custom_job_run.name!r}" + ) + + return custom_job_run + + async def _watch_job_run( + self, + job_name: str, + full_job_name: str, # different from job_name + job_service_client: "JobServiceClient", + current_state: "JobState", + until_states: Tuple["JobState"], + configuration: VertexAIWorkerJobConfiguration, + logger: PrefectLogAdapter, + timeout: int = None, + ) -> "CustomJob": + """ + Polls job run to see if status changed. + """ + state = JobState.JOB_STATE_UNSPECIFIED + last_state = current_state + t0 = time.time() + + while state not in until_states: + job_run = await run_sync_in_worker_thread( + job_service_client.get_custom_job, + name=full_job_name, + ) + state = job_run.state + if state != last_state: + state_label = ( + state.name.replace("_", " ") + .lower() + .replace("state", "state is now:") + ) + # results in "New job state is now: succeeded" + logger.info(f"{job_name} has new {state_label}") + last_state = state + else: + # Intermittently, the job will not be described. We want to respect the + # watch timeout though. + logger.debug(f"Job {job_name} not found.") + + elapsed_time = time.time() - t0 + if timeout is not None and elapsed_time > timeout: + raise RuntimeError( + f"Timed out after {elapsed_time}s while watching job for states " + "{until_states!r}" + ) + time.sleep(configuration.job_watch_poll_interval) + + return job_run + + def _get_compatible_labels( + self, configuration: VertexAIWorkerJobConfiguration + ) -> Dict[str, str]: + """ + Ensures labels are compatible with GCP label requirements. + https://cloud.google.com/resource-manager/docs/creating-managing-labels + + Ex: the Prefect provided key of prefect.io/flow-name -> prefect-io_flow-name + """ + compatible_labels = {} + for key, val in configuration.labels.items(): + new_key = slugify( + key, + lowercase=True, + replacements=[("/", "_"), (".", "-")], + max_length=63, + regex_pattern=_DISALLOWED_GCP_LABEL_CHARACTERS, + ) + compatible_labels[new_key] = slugify( + val, + lowercase=True, + replacements=[("/", "_"), (".", "-")], + max_length=63, + regex_pattern=_DISALLOWED_GCP_LABEL_CHARACTERS, + ) + return compatible_labels + + async def kill_infrastructure( + self, + infrastructure_pid: str, + configuration: VertexAIWorkerJobConfiguration, + grace_seconds: int = 30, + ): + """ + Stops a job running in Vertex AI upon flow cancellation, + based on the provided infrastructure PID + run configuration. + """ + if grace_seconds != 30: + self._logger.warning( + f"Kill grace period of {grace_seconds}s requested, but GCP does not " + "support dynamic grace period configuration. See here for more info: " + "https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs/cancel" # noqa + ) + + client_options = ClientOptions( + api_endpoint=f"{configuration.region}-aiplatform.googleapis.com" + ) + with configuration.credentials.get_job_service_client( + client_options=client_options + ) as job_service_client: + await run_sync_in_worker_thread( + self._stop_job, + client=job_service_client, + vertex_job_name=infrastructure_pid, + ) + + def _stop_job(self, client: "JobServiceClient", vertex_job_name: str): + """ + Calls the `cancel_custom_job` method on the Vertex AI Job Service Client. + """ + cancel_custom_job_request = CancelCustomJobRequest(name=vertex_job_name) + try: + client.cancel_custom_job( + request=cancel_custom_job_request, + ) + except Exception as exc: + if "does not exist" in str(exc): + raise InfrastructureNotFound( + f"Cannot stop Vertex AI job; the job name {vertex_job_name!r} " + "could not be found." + ) from exc + raise diff --git a/tests/test_worker.py b/tests/test_cloud_run_worker.py similarity index 99% rename from tests/test_worker.py rename to tests/test_cloud_run_worker.py index da1be94e..119e8f2f 100644 --- a/tests/test_worker.py +++ b/tests/test_cloud_run_worker.py @@ -11,7 +11,7 @@ from prefect.server.schemas.actions import DeploymentCreate from prefect_gcp.credentials import GcpCredentials -from prefect_gcp.worker import ( +from prefect_gcp.workers.cloud_run import ( CloudRunWorker, CloudRunWorkerJobConfiguration, CloudRunWorkerResult, @@ -354,7 +354,7 @@ def get_mock_client(*args, **kwargs): return m monkeypatch.setattr( - "prefect_gcp.worker.CloudRunWorker._get_client", + "prefect_gcp.workers.cloud_run.CloudRunWorker._get_client", get_mock_client, ) diff --git a/tests/test_vertex_worker.py b/tests/test_vertex_worker.py new file mode 100644 index 00000000..11d73a50 --- /dev/null +++ b/tests/test_vertex_worker.py @@ -0,0 +1,264 @@ +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock + +import anyio +import pydantic +import pytest +from google.cloud.aiplatform_v1.types.job_service import CancelCustomJobRequest +from google.cloud.aiplatform_v1.types.job_state import JobState +from prefect.client.schemas import FlowRun +from prefect.exceptions import InfrastructureNotFound + +from prefect_gcp.workers.vertex import ( + VertexAIWorker, + VertexAIWorkerJobConfiguration, + VertexAIWorkerResult, +) + + +@pytest.fixture +def job_config(service_account_info, gcp_credentials): + return VertexAIWorkerJobConfiguration( + name="my-custom-ai-job", + region="ashenvale", + credentials=gcp_credentials, + job_spec={ + "maximum_run_time_hours": 1, + "worker_pool_specs": [ + { + "replica_count": 1, + "container_spec": { + "image_uri": "gcr.io/your-project/your-repo:latest", + "command": ["python", "-m", "prefect.engine"], + }, + "machine_spec": { + "machine_type": "n1-standard-4", + "accelerator_type": "NVIDIA_TESLA_K80", + "accelerator_count": 1, + }, + "disk_spec": { + "boot_disk_type": "pd-ssd", + "boot_disk_size_gb": 100, + }, + } + ], + }, + ) + + +@pytest.fixture +def flow_run(): + return FlowRun(flow_id=uuid.uuid4(), name="my-flow-run-name") + + +class TestVertexAIWorkerJobConfiguration: + async def test_validate_empty_job_spec(self, gcp_credentials): + base_job_template = VertexAIWorker.get_default_base_job_template() + base_job_template["job_configuration"]["job_spec"] = {} + base_job_template["job_configuration"]["region"] = "us-central1" + + with pytest.raises(pydantic.ValidationError) as excinfo: + await VertexAIWorkerJobConfiguration.from_template_and_values( + base_job_template, {"credentials": gcp_credentials} + ) + + assert excinfo.value.errors() == [ + { + "loc": ("job_spec",), + "msg": ( + "Job is missing required attributes at the following paths: " + "/maximum_run_time_hours, /worker_pool_specs" + ), + "type": "value_error", + } + ] + + async def test_validate_incomplete_worker_pool_spec(self, gcp_credentials): + base_job_template = VertexAIWorker.get_default_base_job_template() + base_job_template["job_configuration"]["job_spec"] = { + "maximum_run_time_hours": 1, + "worker_pool_specs": [ + { + "replica_count": 1, + "container_spec": {"command": ["some", "command"]}, + "machine_spec": { + "accelerator_type": "NVIDIA_TESLA_K80", + }, + }, + ], + } + base_job_template["job_configuration"]["region"] = "us-central1" + + with pytest.raises(pydantic.ValidationError) as excinfo: + await VertexAIWorkerJobConfiguration.from_template_and_values( + base_job_template, {"credentials": gcp_credentials} + ) + + assert excinfo.value.errors() == [ + { + "loc": ("job_spec",), + "msg": ( + "Job is missing required attributes at the following paths: " + "/worker_pool_specs/0/container_spec/image_uri, " + "/worker_pool_specs/0/disk_spec, " + "/worker_pool_specs/0/machine_spec/machine_type" + ), + "type": "value_error", + } + ] + + def test_gcp_project(self, job_config: VertexAIWorkerJobConfiguration): + assert job_config.project == "gcp_credentials_project" + + def test_job_name(self, flow_run, job_config: VertexAIWorkerJobConfiguration): + job_config.prepare_for_flow_run(flow_run, None, None) + assert job_config.job_name.startswith("my-custom-ai-job") + + job_config.name = None + job_config.prepare_for_flow_run(flow_run, None, None) + assert job_config.job_name.startswith("my-flow-run-name") + + async def test_missing_service_account(self, flow_run, job_config): + job_config.job_spec["service_account_name"] = None + job_config.credentials._service_account_email = None + + with pytest.raises( + ValueError, match="A service account is required for the Vertex job" + ): + job_config.prepare_for_flow_run(flow_run, None, None) + + def test_valid_command_formatting( + self, flow_run, job_config: VertexAIWorkerJobConfiguration + ): + job_config.prepare_for_flow_run(flow_run, None, None) + assert ["python", "-m", "prefect.engine"] == job_config.job_spec[ + "worker_pool_specs" + ][0]["container_spec"]["command"] + + job_config.job_spec["worker_pool_specs"][0]["container_spec"][ + "command" + ] = "echo -n hello" + job_config.prepare_for_flow_run(flow_run, None, None) + assert ["echo", "-n", "hello"] == job_config.job_spec["worker_pool_specs"][0][ + "container_spec" + ]["command"] + + +class TestVertexAIWorker: + async def test_successful_worker_run(self, flow_run, job_config): + async with VertexAIWorker("test-pool") as worker: + job_config.prepare_for_flow_run(flow_run, None, None) + result = await worker.run(flow_run=flow_run, configuration=job_config) + assert ( + job_config.credentials.job_service_client.create_custom_job.call_count + == 1 + ) + assert ( + job_config.credentials.job_service_client.get_custom_job.call_count == 1 + ) + assert result == VertexAIWorkerResult( + status_code=0, identifier="mock_display_name" + ) + + async def test_failed_worker_run(self, flow_run, job_config): + job_config.prepare_for_flow_run(flow_run, None, None) + error_msg = "something went kablooey" + error_job_display_name = "catastrophization" + job_config.credentials.job_service_client.get_custom_job.return_value = ( + MagicMock( + name="error_mock_name", + state=JobState.JOB_STATE_FAILED, + error=MagicMock(message=error_msg), + display_name=error_job_display_name, + ) + ) + async with VertexAIWorker("test-pool") as worker: + with pytest.raises(RuntimeError, match=error_msg): + await worker.run(flow_run=flow_run, configuration=job_config) + + assert ( + job_config.credentials.job_service_client.create_custom_job.call_count + == 1 + ) + assert ( + job_config.credentials.job_service_client.get_custom_job.call_count == 1 + ) + + async def test_cancelled_worker_run(self, flow_run, job_config): + job_config.prepare_for_flow_run(flow_run, None, None) + job_display_name = "a-job-well-done" + job_config.credentials.job_service_client.get_custom_job.return_value = ( + MagicMock( + name="cancelled_mock_name", + state=JobState.JOB_STATE_CANCELLED, + error=MagicMock(message=""), + display_name=job_display_name, + ) + ) + async with VertexAIWorker("test-pool") as worker: + result = await worker.run(flow_run=flow_run, configuration=job_config) + assert ( + job_config.credentials.job_service_client.create_custom_job.call_count + == 1 + ) + assert ( + job_config.credentials.job_service_client.get_custom_job.call_count == 1 + ) + assert result == VertexAIWorkerResult( + status_code=1, identifier=job_display_name + ) + + async def test_kill_infrastructure(self, flow_run, job_config): + mock = job_config.credentials.job_service_client.create_custom_job + # the CancelCustomJobRequest class seems to reject a MagicMock value + # so here, we'll use a SimpleNamespace as the mocked return values + mock.return_value = SimpleNamespace( + name="foobar", state=JobState.JOB_STATE_PENDING + ) + + async with VertexAIWorker("test-pool") as worker: + with anyio.fail_after(10): + async with anyio.create_task_group() as tg: + result = await tg.start(worker.run, flow_run, job_config) + await worker.kill_infrastructure(result, job_config) + + mock = job_config.credentials.job_service_client.cancel_custom_job + assert mock.call_count == 1 + mock.assert_called_with(request=CancelCustomJobRequest(name="foobar")) + + async def test_kill_infrastructure_no_grace_seconds( + self, flow_run, job_config, caplog + ): + mock = job_config.credentials.job_service_client.create_custom_job + mock.return_value = SimpleNamespace( + name="bazzbar", state=JobState.JOB_STATE_PENDING + ) + async with VertexAIWorker("test-pool") as worker: + + input_grace_period = 32 + + with anyio.fail_after(10): + async with anyio.create_task_group() as tg: + identifier = await tg.start(worker.run, flow_run, job_config) + await worker.kill_infrastructure( + identifier, job_config, input_grace_period + ) + for record in caplog.records: + if ( + f"Kill grace period of {input_grace_period}s " + "requested, but GCP does not" + ) in record.msg: + break + else: + raise AssertionError("Expected message not found.") + + async def test_kill_infrastructure_not_found(self, job_config): + async with VertexAIWorker("test-pool") as worker: + job_config.credentials.job_service_client.cancel_custom_job.side_effect = ( + Exception("does not exist") + ) + with pytest.raises( + InfrastructureNotFound, match="Cannot stop Vertex AI job" + ): + await worker.kill_infrastructure("foobarbazz", job_config)