Skip to content

Commit

Permalink
feat: labels class
Browse files Browse the repository at this point in the history
  • Loading branch information
javfg committed Sep 16, 2024
1 parent 98dbd13 commit f34ee7b
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 64 deletions.
44 changes: 18 additions & 26 deletions src/ot_orchestration/operators/gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@
LoggingServiceV2AsyncClient,
)

from ot_orchestration.utils.common import (
GCP_PROJECT_PLATFORM,
GCP_REGION,
prepare_labels,
)
from ot_orchestration.utils.common import GCP_PROJECT_PLATFORM, GCP_ZONE
from ot_orchestration.utils.labels import Labels

CONTAINER_NAME = "workload_container"
LOGGING_REQUEST_INTERVAL = 5
Expand Down Expand Up @@ -289,11 +286,7 @@ class ComputeEngineRunContainerizedWorkloadSensor(BaseSensorOperator):
If set to None or missing, the default project_id for platform is used (GCP_PROJECT_PLATFORM).
zone: The zone where the instance will be created (default is GCP_ZONE).
instance_name: Name of the instance name that will run the workload.
labels: Optional dict of labels to apply to the instance on top of the default ones, which
are `team: open-targets`, `product: platform`, `environment: development` or `production`,
and `created_by: unified-orchestrator`. Refer to `controlled vocabularies
<https://github.com/opentargets/controlled-vocabularies/blob/main/infrastructure.yaml>`__
for more information.
labels: Labels to apply to the instance. See the `Labels` class for more information.
container_image: Container image to run.
container_command: Command to run inside the container (optional).
container_args: Arguments to pass to the container (optional).
Expand Down Expand Up @@ -339,15 +332,15 @@ def __init__(
self,
*,
project: str = GCP_PROJECT_PLATFORM,
zone: str = f"{GCP_REGION}-b",
zone: str = GCP_ZONE,
instance_name: str,
labels: dict[str, str] = {},
labels: Labels | None = None,
container_image: str,
container_command: str = "",
container_args: list[str] | None = None,
container_env: dict[str, str] | None = None,
container_service_account: str = "default",
container_scopes: list[str] = [],
container_scopes: list[str] | None = None,
container_files: dict[str, str] | None = None,
machine_type: str = "c3d-standard-8",
work_disk_size_gb: int = 0,
Expand All @@ -363,13 +356,13 @@ def __init__(
self.project = project
self.zone = zone
self.instance_name = instance_name
self.labels = labels
self.labels = labels if labels else Labels()
self.container_image = container_image
self.container_command = container_command
self.container_args = container_args
self.container_env = container_env
self.container_service_account = container_service_account
self.container_scopes = container_scopes
self.container_scopes = container_scopes or []
self.container_files = container_files
self.machine_type = machine_type
self.gcp_conn_id = gcp_conn_id
Expand Down Expand Up @@ -463,15 +456,13 @@ def declare_instance(self) -> compute_v1.InstanceTemplate:
- Network configuration.
- Service account and scopes.
"""
labels = prepare_labels(self.labels, self.project)

boot_disk = compute_v1.AttachedDisk(
auto_delete=True,
boot=True,
initialize_params=compute_v1.AttachedDiskInitializeParams(
disk_type=f"zones/{self.zone}/diskTypes/pd-ssd",
labels=labels,
source_image="projects/cos-cloud/global/images/cos-stable-113-18244-151-9",
labels=self.labels.get(),
source_image="projects/cos-cloud/global/images/cos-113-18244-151-50",
),
)

Expand All @@ -480,7 +471,7 @@ def declare_instance(self) -> compute_v1.InstanceTemplate:
device_name="work-disk",
initialize_params=compute_v1.AttachedDiskInitializeParams(
disk_size_gb=self.work_disk_size_gb,
labels=labels,
labels=self.labels.get(),
disk_type=f"zones/{self.zone}/diskTypes/pd-ssd",
),
)
Expand All @@ -492,7 +483,7 @@ def declare_instance(self) -> compute_v1.InstanceTemplate:
description="unified orchestrator runner instance",
machine_type=f"zones/{self.zone}/machineTypes/{self.machine_type}",
disks=disks,
labels=labels,
labels=self.labels.get(),
metadata=types.Metadata(
items=[
{
Expand Down Expand Up @@ -531,8 +522,8 @@ def declare_instance(self) -> compute_v1.InstanceTemplate:
"https://www.googleapis.com/auth/servicecontrol",
"https://www.googleapis.com/auth/service.management.readonly",
"https://www.googleapis.com/auth/trace.append",
*self.container_scopes,
]
+ self.container_scopes
),
),
],
Expand Down Expand Up @@ -572,17 +563,18 @@ def copy_machine_logs(self) -> None:
page_size=1000,
)
for entry in entries:
self.log.log(level=logging.INFO, msg=entry.payload.get("message", ""))
self.log.info(entry.payload.get("message", "Empty log message"))

def poke(self, context: Context) -> bool:
"""Check if the instance is still running in a synchronous way."""
# We must implement this if we want to run this sensor in a non-deferrable mode.
return NotImplementedError
return False

def execute(self, context: Context) -> bool:
"""Set up and execute the sensor, then start the trigger."""
run = context.get("params", {}).get("run_label", context.get("dag_run").run_id)
self.labels.add({"run": run})
self.start()
self.log.info("Instance created, now checking for the exit code.")

if not self.deferrable:
super().execute(context)
Expand Down Expand Up @@ -724,7 +716,7 @@ async def run(self):
await asyncio.sleep(self.poll_sleep)
except Exception as e:
self.log.error("Error occurred while checking startup script exit code.")
yield TriggerEvent({"status": "error", "message": f"{type(e)}: {str(e)}"})
yield TriggerEvent({"status": "error", "message": f"{type(e)}: {e!r}"})

@cached_property
def hook(self) -> CloudLoggingAsyncHook:
Expand Down
2 changes: 2 additions & 0 deletions src/ot_orchestration/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
clean_name,
read_hocon_config,
read_yaml_config,
strhash,
time_to_seconds,
to_hocon,
to_yaml,
Expand All @@ -31,6 +32,7 @@
"time_to_seconds",
"to_hocon",
"to_yaml",
"strhash",
"create_task_spec",
"create_batch_job",
"extract_study_id_from_path",
Expand Down
36 changes: 8 additions & 28 deletions src/ot_orchestration/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

from __future__ import annotations

from typing import Any

import pendulum

from ot_orchestration.utils.utils import clean_label, strhash
from ot_orchestration.utils import strhash

GENTROPY_VERSION = "0.0.0"

# Cloud configuration.
GCP_PROJECT = "open-targets-genetics-dev"
GCP_PROJECT_PLATFORM = "open-targets-eu-dev"
GCP_PROJECT_ZONE = "europe-west1-b"
GCP_REGION = "europe-west1"
GCP_ZONE = "europe-west1-d"
GCP_DATAPROC_IMAGE = "2.1"
Expand Down Expand Up @@ -47,46 +50,23 @@
}

platform_dag_kwargs = {
"dag_id": "platform_pipeline",
"description": "Open Targets Platform",
"catchup": False,
"schedule": None,
"start_date": pendulum.now(tz="Europe/London").subtract(days=1),
"tags": [
"platform",
"experimental",
],
"tags": ["platform", "experimental"],
"user_defined_filters": {"strhash": strhash},
}

platform_shared_labels = lambda project: {
shared_labels = lambda project: {
"team": "open-targets",
"subteam": "backend",
"product": "platform",
"environment": "development" if "dev" in project else "production",
"created_by": "unified-orchestrator",
}


def prepare_labels(
custom_labels: dict[str, str] = {},
project: str = GCP_PROJECT_PLATFORM,
) -> dict[str, str]:
"""Prepare labels for use in google cloud.
Includes a set of default labels, and ensures that all labels are
correctly formatted.
note: To use outside platform, a way to override the "product" label should
be added.
Args:
custom_labels (dict[str, str]): Custom labels to add to the default labels.
project (str): The name of the project. Defaults to GCP_PROJECT_PLATFORM.
"""
labels = platform_shared_labels(project)
labels.update(custom_labels)

return {k: clean_label(v) for k, v in labels.items()}


def convert_params_to_hydra_positional_arg(
step: dict[str, dict[str, Any]],
) -> list[str] | None:
Expand Down
57 changes: 57 additions & 0 deletions src/ot_orchestration/utils/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Labels for resources in Google Cloud."""

import re

from ot_orchestration.utils.common import GCP_PROJECT_PLATFORM, shared_labels


class Labels:
"""A collection of labels for Google Cloud resources.
Includes a set of default labels, and ensures that all labels are correctly
formatted.
Refer to the `controlled vocabularies <https://github.com/opentargets/controlled-vocabularies/blob/main/infrastructure.yaml>`__
repository for a list of example values.
See the `shared_labels` dict in `common.py` module for the default labels.
Args:
extra: A dict of extra labels to add on top of the defaults.
repository for a list of valid values. Defaults to "platform".
project: The GCP project to use for the labels. This will determine the
content of the default "environment" label. Defaults to
GCP_PROJECT_PLATFORM.
"""

def __init__(
self,
extra: dict[str, str] | None = None,
project: str = GCP_PROJECT_PLATFORM,
) -> None:
self.project = project
self.extra = extra or {}
self.label_dict = shared_labels(project)
self.label_dict.update({k: self.clean_label(v) for k, v in self.extra.items()})

def clean_label(self, label: str) -> str:
"""Clean a label for use in google cloud.
According to the docs: The value can only contain lowercase letters, numeric
characters, underscores and dashes. The value can be at most 63 characters
long.
"""
return re.sub(r"[^a-z0-9-_]", "-", label.lower())[0:63]

def add(self, extra: dict[str, str]) -> None:
"""Add labels to a collection."""
self.label_dict.update({k: self.clean_label(v) for k, v in extra.items()})

def get(self) -> dict[str, str]:
"""Return a dict of clean labels."""
return self.label_dict

def clone(self, extra: dict[str, str] | None = None) -> "Labels":
"""Return a copy with additional labels."""
extra = extra or {}
return Labels({**self.label_dict, **extra}, self.project)
10 changes: 0 additions & 10 deletions src/ot_orchestration/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,6 @@ def check_gcp_folder_exists(bucket_name: str, folder_path: str) -> bool:
return any(blobs)


def clean_label(label: str) -> str:
"""Clean a label for use in google cloud.
According to the docs: The value can only contain lowercase letters, numeric
characters, underscores and dashes. The value can be at most 63 characters
long.
"""
return re.sub(r"[^a-z0-9-_]", "-", label.lower())[0:63]


def clean_name(name: str) -> str:
"""Create a clean name meeting google cloud naming conventions."""
return re.sub(r"[^a-z0-9-]", "-", name.lower())
Expand Down

0 comments on commit f34ee7b

Please sign in to comment.