Skip to content

Commit

Permalink
Fix empty log pager bug and old timestamps for exit code
Browse files Browse the repository at this point in the history
  • Loading branch information
javfg committed Aug 30, 2024
1 parent c824bb6 commit dee178a
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 18 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ module = [
]
ignore_missing_imports = true


[tool.ruff.lint]
ignore = [
"E501", # line too long
Expand Down
31 changes: 30 additions & 1 deletion src/ot_orchestration/common_airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pendulum

from ot_orchestration.utils.utils import strhash
from ot_orchestration.utils.utils import clean_label, strhash

# Code version. It has to be repeated here as well as in `pyproject.toml`, because Airflow isn't able to look at files outside of its `dags/` directory.
GENTROPY_VERSION = "0.0.0"
Expand Down Expand Up @@ -57,3 +57,32 @@
],
"user_defined_filters": {"strhash": strhash},
}

def platform_shared_labels(project: str) -> dict[str, str]:
    """Return the default labels shared by platform resources.

    A new dictionary is built on every call, so callers may freely mutate
    the result.

    Args:
        project (str): The name of the project. If it contains the substring
            "dev", the "environment" label is "development"; otherwise it is
            "production".

    Returns:
        dict[str, str]: The shared label names mapped to their values.
    """
    return {
        "team": "open-targets",
        "subteam": "backend",
        "product": "platform",
        "environment": "development" if "dev" in project else "production",
        "created_by": "unified-orchestrator",
    }


def prepare_labels(
    custom_labels: dict[str, str] | None = None,
    project: str = GCP_PROJECT_PLATFORM,
) -> dict[str, str]:
    """Prepare labels for use in google cloud.

    Starts from the shared platform labels for the given project, overlays
    the custom labels on top, and passes every resulting value through
    `clean_label`. Label keys are used as given; only values are cleaned.

    note: To use outside platform, a way to override the "product" label should
    be added.

    Args:
        custom_labels (dict[str, str] | None): Custom labels to add to the
            default labels. On a key clash the custom value wins. Defaults to
            None, which means no custom labels. The passed dictionary is not
            modified.
        project (str): The name of the project. Defaults to GCP_PROJECT_PLATFORM.

    Returns:
        dict[str, str]: The merged labels with cleaned values.
    """
    labels = platform_shared_labels(project)
    # `None` replaces the former `{}` default so that no mutable object is
    # shared between calls.
    labels.update(custom_labels or {})

    return {k: clean_label(v) for k, v in labels.items()}
46 changes: 30 additions & 16 deletions src/ot_orchestration/operators/gce.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Custom sensor that runs a containerized workload on a Google Compute Engine instance."""

import asyncio
import datetime
import logging
import time
from functools import cached_property
Expand All @@ -24,8 +25,11 @@
LoggingServiceV2AsyncClient,
)

from ot_orchestration.common_airflow import GCP_PROJECT_PLATFORM, GCP_REGION
from ot_orchestration.utils.utils import clean_label
from ot_orchestration.common_airflow import (
GCP_PROJECT_PLATFORM,
GCP_REGION,
prepare_labels,
)

CONTAINER_NAME = "workload_container"
LOGGING_REQUEST_INTERVAL = 5
Expand Down Expand Up @@ -98,6 +102,7 @@ def __init__(self, log: logging.Logger, *args, **kwargs):
self.request_interval = LOGGING_REQUEST_INTERVAL

def list_entries(self, *args, **kwargs):
"""List log entries and retries request that get rate-limited."""
entries = None

while True:
Expand Down Expand Up @@ -143,6 +148,7 @@ def __init__(
self.api_version = api_version

def get_conn(self) -> RateLimitedLoggingClient:
"""Return the Google Cloud Logging service client."""
if self._client is None:
self._client = RateLimitedLoggingClient(
log=self.log,
Expand Down Expand Up @@ -181,6 +187,7 @@ def __init__(
self.request_interval = LOGGING_REQUEST_INTERVAL

def get_conn(self) -> LoggingServiceV2AsyncClient:
"""Return the Google Cloud Logging service client."""
if self._client is None:
self._client = LoggingServiceV2AsyncClient(
credentials=self.get_credentials(),
Expand All @@ -192,6 +199,7 @@ async def get_exit_code(
self,
project_name: str,
instance_name: str,
initial_timestamp: datetime.datetime,
) -> int:
"""Get the exit code of the startup script of a Google Compute Engine instance.
Expand All @@ -216,8 +224,8 @@ async def get_exit_code(
Script "startup-script" failed with error: exit status 1
"""
client = self.get_conn()

query = f'resource.type="gce_instance" labels.instance_name="{instance_name}" jsonPayload.message=~"startup-script[\w\\\":\s]*exit status [0-9]+"' # fmt: skip
timestamp_str = initial_timestamp.isoformat()
query = f'resource.type="gce_instance" labels.instance_name="{instance_name}" timestamp>"{timestamp_str}" jsonPayload.message=~"startup-script[\w\\\":\s]*exit status [0-9]+"' # fmt: skip
log_pages = None

while True:
Expand All @@ -244,11 +252,17 @@ async def get_exit_code(
await asyncio.sleep(self.request_interval)
self.request_interval *= 2

first_page = await anext(log_pages.pages, None)
if first_page and first_page.entries:
entry = first_page.entries[0]
logs = None
try:
logs = await anext(log_pages.pages, None)
except Exception as e:
self.log.error("Error occurred while fetching log entries: %s", e)

if logs and logs.entries:
entry = logs.entries[0]
return int(entry.json_payload["message"].split("exit status", 1)[1].strip())
self.log.info("No log pages entries found, returning None.")

self.log.info("No log entries with an exit status found yet.")
return None


Expand Down Expand Up @@ -447,13 +461,7 @@ def declare_instance(self) -> compute_v1.InstanceTemplate:
- Network configuration.
- Service account and scopes.
"""
labels = {
"team": "open-targets",
"product": "platform",
"environment": "development" if "dev" in self.project else "production",
"created_by": "unified-orchestrator",
**{k: clean_label(v) for k, v in self.labels.items()},
}
labels = prepare_labels(self.labels, self.project)

boot_disk = compute_v1.AttachedDisk(
auto_delete=True,
Expand Down Expand Up @@ -605,13 +613,15 @@ def execute_complete(self, context: Context, event: dict[str, str | list]) -> No

@cached_property
def hook(self) -> ComputeEngineHook:
    """Build the Google Compute Engine hook from this object's connection settings.

    Because of `cached_property`, the hook is created on first access and
    reused for later accesses on the same instance.
    """
    connection_settings = {
        "gcp_conn_id": self.gcp_conn_id,
        "impersonation_chain": self.impersonation_chain,
    }
    return ComputeEngineHook(**connection_settings)

@cached_property
def logging_hook(self) -> CloudLoggingHook:
"""Return the Google Cloud Logging hook."""
return CloudLoggingHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -653,6 +663,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.poll_sleep = poll_sleep
self.timestamp = datetime.datetime.now(datetime.timezone.utc)

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize class arguments and classpath."""
Expand Down Expand Up @@ -683,7 +694,9 @@ async def run(self):
try:
while True:
exit_code = await self.hook.get_exit_code(
self.project, self.instance_name
self.project,
self.instance_name,
self.timestamp,
)

self.log.info(f"VM {self.instance_name} exit code is {exit_code}")
Expand Down Expand Up @@ -712,6 +725,7 @@ async def run(self):

@cached_property
def hook(self) -> CloudLoggingAsyncHook:
"""Return the Google Cloud Logging async hook."""
return CloudLoggingAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down

0 comments on commit dee178a

Please sign in to comment.