Add BaseWorker and ProcessWorker (#7996)
desertaxle authored and masonmenges committed Jan 10, 2023
1 parent fb8bd90 commit 2867e90
Showing 14 changed files with 1,678 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/prefect/cli/__init__.py
@@ -1,4 +1,5 @@
from prefect.cli.root import app
import prefect.settings

# Import CLI submodules to register them to the app
# isort: split
@@ -17,3 +18,7 @@
import prefect.cli.orion_utils
import prefect.cli.profile
import prefect.cli.work_queue

# Only load workers CLI if enabled via a setting
if prefect.settings.PREFECT_EXPERIMENTAL_ENABLE_WORKERS.value():
import prefect.experimental.cli.worker
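
This gate can be exercised with Prefect's `temporary_settings` helper; a minimal sketch (the assertion is illustrative — in practice the setting would typically be set via `prefect config set PREFECT_EXPERIMENTAL_ENABLE_WORKERS=True` or an environment variable):

```python
from prefect.settings import PREFECT_EXPERIMENTAL_ENABLE_WORKERS, temporary_settings

# Flip the experimental flag for the duration of the context; the workers
# CLI is only registered when this evaluates truthy at import time.
with temporary_settings({PREFECT_EXPERIMENTAL_ENABLE_WORKERS: True}):
    assert PREFECT_EXPERIMENTAL_ENABLE_WORKERS.value() is True
```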
188 changes: 188 additions & 0 deletions src/prefect/client/orion.py
@@ -31,8 +31,11 @@
BlockType,
FlowRunNotificationPolicy,
QueueFilter,
WorkerPool,
WorkerPoolQueue,
)
from prefect.orion.schemas.filters import FlowRunNotificationPolicyFilter, LogFilter
from prefect.orion.schemas.responses import WorkerFlowRunResponse
from prefect.settings import (
PREFECT_API_ENABLE_HTTP2,
PREFECT_API_KEY,
@@ -1263,6 +1266,8 @@ async def create_deployment(
parameters: Dict[str, Any] = None,
description: str = None,
work_queue_name: str = None,
worker_pool_name: str = None,
worker_pool_queue_name: str = None,
tags: List[str] = None,
storage_document_id: UUID = None,
manifest_path: str = None,
@@ -1300,6 +1305,8 @@ async def create_deployment(
parameters=dict(parameters or {}),
tags=list(tags or []),
work_queue_name=work_queue_name,
worker_pool_name=worker_pool_name,
worker_pool_queue_name=worker_pool_queue_name,
description=description,
storage_document_id=storage_document_id,
path=path,
@@ -1942,6 +1949,187 @@ async def resolve_inner(data):

return await resolve_inner(datadoc)

async def send_worker_heartbeat(self, worker_pool_name: str, worker_name: str):
"""
Sends a worker heartbeat for a given worker pool.
Args:
worker_pool_name: The name of the worker pool to heartbeat against.
worker_name: The name of the worker sending the heartbeat.
"""
await self._client.post(
f"/experimental/worker_pools/{worker_pool_name}/workers/heartbeat",
json={"name": worker_name},
)
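
A usage sketch for the heartbeat call (pool and worker names are hypothetical; a real worker would invoke this on a fixed interval):

```python
from prefect.client.orion import get_client

async def heartbeat_once() -> None:
    async with get_client() as client:
        # Report liveness for a worker; the server uses these heartbeats
        # to track which workers in the pool are online.
        await client.send_worker_heartbeat(
            worker_pool_name="my-pool",  # hypothetical pool name
            worker_name="worker-1",      # hypothetical worker name
        )
```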

async def read_workers_for_worker_pool(
self,
worker_pool_name: str,
worker_filter: Optional[schemas.filters.WorkerFilter] = None,
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> List[schemas.core.Worker]:
"""
Reads workers for a given worker pool.
Args:
worker_pool_name: The name of the worker pool for which to get
member workers.
worker_filter: Criteria by which to filter workers.
limit: Limit for the worker query.
offset: Offset for the worker query.
"""
response = await self._client.post(
f"/experimental/worker_pools/{worker_pool_name}/workers/filter",
json={
"worker_filter": (
worker_filter.dict(json_compatible=True, exclude_unset=True)
if worker_filter
else None
),
"offset": offset,
"limit": limit,
},
)

return pydantic.parse_obj_as(List[schemas.core.Worker], response.json())
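
A short sketch of paging through a pool's workers with the offset/limit parameters (pool name hypothetical; pass a `WorkerFilter` to narrow the results):

```python
from prefect.client.orion import get_client

async def list_workers(pool: str = "my-pool") -> None:  # hypothetical pool name
    async with get_client() as client:
        # First page of up to 50 workers; no filter applied.
        workers = await client.read_workers_for_worker_pool(
            worker_pool_name=pool, offset=0, limit=50
        )
        for worker in workers:
            print(worker.name)
```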

async def read_worker_pool(self, worker_pool_name: str) -> schemas.core.WorkerPool:
"""
Reads information for a given worker pool.
Args:
worker_pool_name: The name of the worker pool for which to get
information.
Returns:
Information about the requested worker pool.
"""
try:
response = await self._client.get(
f"/experimental/worker_pools/{worker_pool_name}"
)
return pydantic.parse_obj_as(WorkerPool, response.json())
except httpx.HTTPStatusError as e:
if e.response.status_code == status.HTTP_404_NOT_FOUND:
raise prefect.exceptions.ObjectNotFound(http_exc=e) from e
else:
raise

async def create_worker_pool(
self,
worker_pool: schemas.actions.WorkerPoolCreate,
) -> schemas.core.WorkerPool:
"""
Creates a worker pool with the provided configuration.
Args:
worker_pool: Desired configuration for the new worker pool.
Returns:
Information about the newly created worker pool.
"""
response = await self._client.post(
"/experimental/worker_pools/",
json=worker_pool.dict(json_compatible=True, exclude_unset=True),
)

return pydantic.parse_obj_as(WorkerPool, response.json())
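
Since a missing pool surfaces as `ObjectNotFound`, the two methods compose into a read-or-create pattern; a sketch assuming `WorkerPoolCreate` requires only a `name` (other fields, such as a pool type, may apply depending on the schema):

```python
from prefect.client.orion import get_client
from prefect.exceptions import ObjectNotFound
from prefect.orion.schemas.actions import WorkerPoolCreate

async def read_or_create_pool(name: str):
    async with get_client() as client:
        try:
            return await client.read_worker_pool(worker_pool_name=name)
        except ObjectNotFound:
            # Assumption: a name is sufficient to create a pool.
            return await client.create_worker_pool(
                worker_pool=WorkerPoolCreate(name=name)
            )
```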

async def read_worker_pool_queues(
self, worker_pool_name: str
) -> List[schemas.core.WorkerPoolQueue]:
"""
Retrieves queues for a worker pool.
Args:
worker_pool_name: Name of the worker pool for which to get queues.
Returns:
List of queues for the specified worker pool.
"""
response = await self._client.get(
f"/experimental/worker_pools/{worker_pool_name}/queues"
)

return pydantic.parse_obj_as(List[WorkerPoolQueue], response.json())
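
Listing a pool's queues is a plain GET; a one-liner sketch (pool name hypothetical):

```python
from prefect.client.orion import get_client

async def show_queues(pool: str = "my-pool") -> None:  # hypothetical pool name
    async with get_client() as client:
        for queue in await client.read_worker_pool_queues(worker_pool_name=pool):
            print(queue.name)
```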

async def create_worker_pool_queue(
self,
worker_pool_name: str,
worker_pool_queue: schemas.actions.WorkerPoolQueueCreate,
) -> schemas.core.WorkerPoolQueue:
"""
Creates a queue for a given worker pool.
Args:
worker_pool_name: Name of the worker pool to create the queue under.
worker_pool_queue: Desired configuration for the new queue.
Returns:
Information about the newly created queue.
"""
response = await self._client.post(
f"/experimental/worker_pools/{worker_pool_name}/queues",
json=worker_pool_queue.dict(json_compatible=True, exclude_unset=True),
)

return pydantic.parse_obj_as(WorkerPoolQueue, response.json())

async def update_worker_pool_queue(
self,
worker_pool_name: str,
worker_pool_queue_name: str,
worker_pool_queue: schemas.actions.WorkerPoolQueueUpdate,
):
"""
Updates a queue for a given worker pool.
Args:
worker_pool_name: Name of the worker pool in which the queue resides.
worker_pool_queue_name: Name of the worker pool queue to update.
worker_pool_queue: Desired updates for the queue.
"""
await self._client.patch(
f"/experimental/worker_pools/{worker_pool_name}/queues/{worker_pool_queue_name}",
json=worker_pool_queue.dict(json_compatible=True, exclude_unset=True),
)
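
A sketch that provisions a queue and then updates it, assuming `WorkerPoolQueueCreate` takes a `name` and `WorkerPoolQueueUpdate` exposes an `is_paused` flag (both are assumptions about the action schemas):

```python
from prefect.client.orion import get_client
from prefect.orion.schemas.actions import (
    WorkerPoolQueueCreate,
    WorkerPoolQueueUpdate,
)

async def provision_queue(pool: str = "my-pool") -> None:  # hypothetical pool name
    async with get_client() as client:
        queue = await client.create_worker_pool_queue(
            worker_pool_name=pool,
            worker_pool_queue=WorkerPoolQueueCreate(name="high-priority"),
        )
        # The PATCH serializes with exclude_unset=True, so only fields
        # explicitly set here are changed on the server.
        await client.update_worker_pool_queue(
            worker_pool_name=pool,
            worker_pool_queue_name=queue.name,
            worker_pool_queue=WorkerPoolQueueUpdate(is_paused=True),
        )
```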

async def get_scheduled_flow_runs_for_worker_pool_queues(
self,
worker_pool_name: str,
worker_pool_queue_names: Optional[List[str]] = None,
scheduled_before: Optional[datetime.datetime] = None,
) -> List[WorkerFlowRunResponse]:
"""
Retrieves scheduled flow runs for the provided set of worker pool queues.
Args:
worker_pool_name: The name of the worker pool that the worker pool
queues are associated with.
worker_pool_queue_names: The names of the worker pool queues from which
to get scheduled flow runs.
scheduled_before: Datetime used to filter returned flow runs. Flow runs
scheduled after the given datetime will not be returned.
Returns:
A list of worker flow run responses containing information about the
retrieved flow runs.
"""
body: Dict[str, Any] = {}
if worker_pool_queue_names is not None:
body["worker_pool_queue_names"] = worker_pool_queue_names
if scheduled_before:
body["scheduled_before"] = str(scheduled_before)

response = await self._client.post(
f"/experimental/worker_pools/{worker_pool_name}/get_scheduled_flow_runs",
json=body,
)

return pydantic.parse_obj_as(List[WorkerFlowRunResponse], response.json())
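
This endpoint is the core of a polling worker's loop: fetch the runs that are due, then submit them to infrastructure. A minimal polling sketch (pool and queue names hypothetical; assumes each response exposes the underlying `flow_run`):

```python
import pendulum

from prefect.client.orion import get_client

async def poll_once(pool: str = "my-pool") -> None:  # hypothetical pool name
    async with get_client() as client:
        responses = await client.get_scheduled_flow_runs_for_worker_pool_queues(
            worker_pool_name=pool,
            worker_pool_queue_names=["default"],   # hypothetical queue name
            scheduled_before=pendulum.now("UTC"),  # only runs due now or earlier
        )
        for response in responses:
            # A real worker (e.g. ProcessWorker) would submit each run here.
            print(response.flow_run.id)
```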

async def __aenter__(self):
"""
Start the client.
23 changes: 21 additions & 2 deletions src/prefect/deployments.py
@@ -6,6 +6,7 @@
import json
import sys
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
from uuid import UUID
@@ -15,6 +16,7 @@
import yaml
from pydantic import BaseModel, Field, parse_obj_as, validator

from prefect._internal.compatibility.experimental import experimental_field
from prefect.blocks.core import Block
from prefect.blocks.fields import SecretDict
from prefect.client.orion import OrionClient, get_client
@@ -206,6 +208,13 @@ def load_deployments_from_yaml(
return registry


@experimental_field("worker_pool_name", group="workers", when=lambda x: x is not None)
@experimental_field(
"worker_pool_queue_name",
group="workers",
when=lambda x: x is not None,
stacklevel=4,
)
class Deployment(BaseModel):
"""
A Prefect Deployment definition, used for specifying and building deployments.
@@ -277,6 +286,8 @@ def _editable_fields(self) -> List[str]:
"description",
"version",
"work_queue_name",
"worker_pool_name",
"worker_pool_queue_name",
"tags",
"parameters",
"schedule",
@@ -377,7 +388,12 @@ def _yaml_dict(self) -> dict:
description="The work queue for the deployment.",
yaml_comment="The work queue that will handle this deployment's runs",
)

worker_pool_name: Optional[str] = Field(
default=None, description="The worker pool for the deployment"
)
worker_pool_queue_name: Optional[str] = Field(
default=None, description="The worker pool queue for the deployment."
)
# flow data
parameters: Dict[str, Any] = Field(default_factory=dict)
manifest_path: Optional[str] = Field(
@@ -405,6 +421,7 @@ def _yaml_dict(self) -> dict:
default_factory=ParameterSchema,
description="The parameter schema of the flow, including defaults.",
)
timestamp: datetime = Field(default_factory=partial(pendulum.now, "UTC"))

@validator("infrastructure", pre=True)
def infrastructure_must_have_capabilities(cls, value):
@@ -499,7 +516,7 @@ async def load(self) -> bool:
)

excluded_fields = self.__fields_set__.union(
{"infrastructure", "storage"}
{"infrastructure", "storage", "timestamp"}
)
for field in set(self.__fields__.keys()) - excluded_fields:
new_value = getattr(deployment, field)
@@ -634,6 +651,8 @@ async def apply(
flow_id=flow_id,
name=self.name,
work_queue_name=self.work_queue_name,
worker_pool_name=self.worker_pool_name,
worker_pool_queue_name=self.worker_pool_queue_name,
version=self.version,
schedule=self.schedule,
parameters=self.parameters,
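
With the new experimental fields, a deployment can target a worker pool queue directly; a sketch assuming an existing flow (all names hypothetical — building this emits the experimental-feature warning added by `experimental_field`):

```python
from prefect import flow
from prefect.deployments import Deployment

@flow
def my_flow():
    pass

async def deploy() -> None:
    deployment = await Deployment.build_from_flow(
        flow=my_flow,
        name="workers-demo",                     # hypothetical deployment name
        worker_pool_name="my-pool",              # hypothetical pool name
        worker_pool_queue_name="high-priority",  # hypothetical queue name
    )
    # Registers the deployment, passing the worker pool fields through
    # OrionClient.create_deployment as shown above.
    await deployment.apply()
```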