Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Makes agents work pool aware #8222

Merged
merged 10 commits into from
Jan 23, 2023
70 changes: 60 additions & 10 deletions src/prefect/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
their execution.
"""
import inspect
from typing import Iterator, List, Optional, Set, Union
from typing import AsyncIterator, List, Optional, Set, Union
from uuid import UUID

import anyio
import anyio.abc
import anyio.to_process
import pendulum

from prefect._internal.compatibility.experimental import experimental_parameter
from prefect.blocks.core import Block
from prefect.client.orion import OrionClient, get_client
from prefect.engine import propose_state
Expand All @@ -22,7 +23,8 @@
)
from prefect.infrastructure import Infrastructure, InfrastructureResult, Process
from prefect.logging import get_logger
from prefect.orion.schemas.core import BlockDocument, FlowRun, WorkQueue
from prefect.orion import schemas
from prefect.orion.schemas.core import BlockDocument, FlowRun, WorkPoolQueue, WorkQueue
from prefect.orion.schemas.filters import (
FlowRunFilter,
FlowRunFilterId,
Expand All @@ -36,10 +38,14 @@


class OrionAgent:
@experimental_parameter(
"work_pool_name", group="workers", when=lambda y: y is not None
)
def __init__(
self,
work_queues: List[str] = None,
work_queue_prefix: Union[str, List[str]] = None,
work_pool_name: str = None,
prefetch_seconds: int = None,
default_infrastructure: Infrastructure = None,
default_infrastructure_document_id: UUID = None,
Expand All @@ -52,6 +58,7 @@ def __init__(
)

self.work_queues: Set[str] = set(work_queues) if work_queues else set()
self.work_pool_name = work_pool_name
self.prefetch_seconds = prefetch_seconds
self.submitting_flow_run_ids = set()
self.cancelling_flow_run_ids = set()
Expand Down Expand Up @@ -84,7 +91,19 @@ def __init__(

async def update_matched_agent_work_queues(self):
if self.work_queue_prefix:
matched_queues = await self.client.match_work_queues(self.work_queue_prefix)
if self.work_pool_name:
matched_queues = await self.client.read_work_pool_queues(
work_pool_name=self.work_pool_name,
work_pool_queue_filter=schemas.filters.WorkPoolQueueFilter(
name=schemas.filters.WorkPoolQueueFilterName(
startswith_=self.work_queue_prefix
)
),
)
else:
matched_queues = await self.client.match_work_queues(
self.work_queue_prefix
)
matched_queues = set(q.name for q in matched_queues)
if matched_queues != self.work_queues:
new_queues = matched_queues - self.work_queues
Expand All @@ -99,7 +118,7 @@ async def update_matched_agent_work_queues(self):
)
self.work_queues = matched_queues

async def get_work_queues(self) -> Iterator[WorkQueue]:
async def get_work_queues(self) -> AsyncIterator[Union[WorkQueue, WorkPoolQueue]]:
"""
Loads the work queue objects corresponding to the agent's target work
queues. If any of them don't exist, they are created.
Expand All @@ -121,16 +140,32 @@ async def get_work_queues(self) -> Iterator[WorkQueue]:

for name in self.work_queues:
try:
work_queue = await self.client.read_work_queue_by_name(name)
if self.work_pool_name:
work_queue = await self.client.read_work_pool_queue(
work_pool_name=self.work_pool_name, work_pool_queue_name=name
)
else:
work_queue = await self.client.read_work_queue_by_name(name)
except ObjectNotFound:

# if the work queue wasn't found, create it
if not self.work_queue_prefix:
# do not attempt to create work queues if the agent is polling for
# queues using a regex
try:
work_queue = await self.client.create_work_queue(name=name)
self.logger.info(f"Created work queue '{name}'.")
if self.work_pool_name:
work_queue = await self.client.create_work_pool_queue(
work_pool_name=self.work_pool_name,
work_pool_queue=schemas.actions.WorkPoolQueueCreate(
name=name
),
)
self.logger.info(
f"Created work queue {name!r} in work pool {self.work_pool_name!r}."
)
else:
work_queue = await self.client.create_work_queue(name=name)
self.logger.info(f"Created work queue '{name}'.")

# if creating it raises an exception, it was probably just
# created by some other agent; rather than entering a re-read
Expand Down Expand Up @@ -159,6 +194,13 @@ async def get_and_submit_flow_runs(self) -> List[FlowRun]:

submittable_runs: List[FlowRun] = []

if self.work_pool_name and not self.work_queues:
responses = await self.client.get_scheduled_flow_runs_for_work_pool_queues(
work_pool_name=self.work_pool_name,
scheduled_before=before,
)
submittable_runs.extend([response.flow_run for response in responses])

# load runs from each work queue
async for work_queue in self.get_work_queues():

Expand All @@ -170,9 +212,17 @@ async def get_and_submit_flow_runs(self) -> List[FlowRun]:

else:
try:
queue_runs = await self.client.get_runs_in_work_queue(
id=work_queue.id, limit=10, scheduled_before=before
)
if isinstance(work_queue, WorkPoolQueue):
responses = await self.client.get_scheduled_flow_runs_for_work_pool_queues(
work_pool_name=self.work_pool_name,
work_pool_queue_names=[work_queue.name],
scheduled_before=before,
)
queue_runs = [response.flow_run for response in responses]
else:
queue_runs = await self.client.get_runs_in_work_queue(
id=work_queue.id, limit=10, scheduled_before=before
)
submittable_runs.extend(queue_runs)
except ObjectNotFound:
self.logger.error(
Expand Down
13 changes: 12 additions & 1 deletion src/prefect/cli/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ async def start(
"for example `dev-` will match all work queues with a name that starts with `dev-`"
),
),
work_pool_name: str = typer.Option(
None,
"-p",
"--pool",
help="A work pool name for the agent to pull from.",
),
hide_welcome: bool = typer.Option(False, "--hide-welcome"),
api: str = SettingsOption(PREFECT_API_URL),
run_once: bool = typer.Option(
Expand Down Expand Up @@ -101,13 +107,17 @@ async def start(
style="blue",
)

if not work_queues and not tags and not work_queue_prefix:
if not work_queues and not tags and not work_queue_prefix and not work_pool_name:
exit_with_error("No work queues provided!", style="red")
elif bool(work_queues) + bool(tags) + bool(work_queue_prefix) > 1:
exit_with_error(
"Only one of `work_queues`, `match`, or `tags` can be provided.",
style="red",
)
if work_pool_name and tags:
exit_with_error(
"`tag` and `pool` options cannot be used together.", style="red"
)

if tags:
work_queue_name = f"Agent queue {'-'.join(sorted(tags))}"
Expand Down Expand Up @@ -145,6 +155,7 @@ async def start(
async with OrionAgent(
work_queues=work_queues,
work_queue_prefix=work_queue_prefix,
work_pool_name=work_pool_name,
prefetch_seconds=prefetch_seconds,
limit=limit,
) as agent:
Expand Down
54 changes: 50 additions & 4 deletions src/prefect/client/orion.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
WorkPool,
WorkPoolQueue,
)
from prefect.orion.schemas.filters import FlowRunNotificationPolicyFilter, LogFilter
from prefect.orion.schemas.filters import (
FlowRunNotificationPolicyFilter,
LogFilter,
WorkPoolQueueFilter,
)
from prefect.orion.schemas.responses import WorkerFlowRunResponse
from prefect.settings import (
PREFECT_API_ENABLE_HTTP2,
Expand Down Expand Up @@ -2061,23 +2065,65 @@ async def update_work_pool(
)

async def read_work_pool_queues(
    self,
    work_pool_name: str,
    work_pool_queue_filter: Optional[WorkPoolQueueFilter] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
) -> List[schemas.core.WorkPoolQueue]:
    """
    Retrieves queues for a work pool.

    Args:
        work_pool_name: Name of the work pool for which to get queues.
        work_pool_queue_filter: Criteria by which to filter queues.
        limit: Limit for the queue query.
        offset: Offset for the queue query.

    Returns:
        List of queues for the specified work pool.
    """
    # Serialize with exclude_unset so only explicitly-provided filter
    # criteria are sent; an absent filter is transmitted as null.
    json = {
        # BUGFIX: this key must match the server-side body parameter name
        # (`work_pool_queues` on the `/queues/filter` route). The previous
        # key, "flow_run_notification_policy_filter", was a copy-paste
        # error that caused the filter to be silently ignored.
        "work_pool_queues": work_pool_queue_filter.dict(
            json_compatible=True, exclude_unset=True
        )
        if work_pool_queue_filter
        else None,
        "limit": limit,
        "offset": offset,
    }

    response = await self._client.post(
        f"/experimental/work_pools/{work_pool_name}/queues/filter",
        json=json,
    )

    return pydantic.parse_obj_as(List[WorkPoolQueue], response.json())

async def read_work_pool_queue(
    self, work_pool_name: str, work_pool_queue_name: str
) -> schemas.core.WorkPoolQueue:
    """
    Retrieves a given queue for a work pool.

    Args:
        work_pool_name: Name of the work pool the queue belongs to.
        work_pool_queue_name: Name of the work pool queue to get.

    Returns:
        The specified work pool queue.

    Raises:
        prefect.exceptions.ObjectNotFound: If the queue does not exist.
    """
    url = f"/experimental/work_pools/{work_pool_name}/queues/{work_pool_queue_name}"
    try:
        response = await self._client.get(url)
    except httpx.HTTPStatusError as e:
        # Translate a 404 into the richer ObjectNotFound exception;
        # every other HTTP error propagates unchanged.
        if e.response.status_code != status.HTTP_404_NOT_FOUND:
            raise
        raise prefect.exceptions.ObjectNotFound(http_exc=e) from e
    return pydantic.parse_obj_as(WorkPoolQueue, response.json())

async def create_work_pool_queue(
self,
work_pool_name: str,
Expand Down
12 changes: 10 additions & 2 deletions src/prefect/orion/api/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,12 @@ async def read_work_pool_queue(
)


@router.get("/{work_pool_name}/queues")
@router.post("/{work_pool_name}/queues/filter")
async def read_work_pool_queues(
work_pool_name: str = Path(..., description="The work pool name"),
work_pool_queues: schemas.filters.WorkPoolQueueFilter = None,
limit: int = dependencies.LimitBody(),
offset: int = Body(0, ge=0),
worker_lookups: WorkerLookups = Depends(WorkerLookups),
db: OrionDBInterface = Depends(provide_database_interface),
) -> List[schemas.core.WorkPoolQueue]:
Expand All @@ -388,7 +391,12 @@ async def read_work_pool_queues(
work_pool_name=work_pool_name,
)
return await models.workers.read_work_pool_queues(
session=session, work_pool_id=work_pool_id, db=db
session=session,
work_pool_id=work_pool_id,
work_pool_queue_filter=work_pool_queues,
limit=limit,
offset=offset,
db=db,
)


Expand Down
17 changes: 16 additions & 1 deletion src/prefect/orion/models/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Intended for internal use by the Orion API.
"""
import datetime
from typing import Dict, List
from typing import Dict, List, Optional
from uuid import UUID

import pendulum
Expand Down Expand Up @@ -333,13 +333,20 @@ async def read_work_pool_queues(
session: AsyncSession,
work_pool_id: UUID,
db: OrionDBInterface,
work_pool_queue_filter: Optional[schemas.filters.WorkPoolQueueFilter] = None,
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> List[ORMWorkPoolQueue]:
"""
Read all work pool queues for a work pool. Results are ordered by ascending priority.

Args:
session (AsyncSession): a database session
work_pool_id (UUID): a work pool id
work_pool_queue_filter: Filter criteria for work pool queues
offset: Query offset
limit: Query limit


Returns:
List[db.WorkPoolQueue]: the WorkPoolQueues
Expand All @@ -350,6 +357,14 @@ async def read_work_pool_queues(
.where(db.WorkPoolQueue.work_pool_id == work_pool_id)
.order_by(db.WorkPoolQueue.priority.asc())
)

if work_pool_queue_filter is not None:
query = query.where(work_pool_queue_filter.as_sql_filter(db))
if offset is not None:
query = query.offset(offset)
if limit is not None:
query = query.limit(limit)

result = await session.execute(query)
return result.scalars().unique().all()

Expand Down
18 changes: 18 additions & 0 deletions src/prefect/orion/schemas/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,11 +1372,29 @@ class WorkPoolQueueFilterName(PrefectFilterBaseModel):
any_: Optional[List[str]] = Field(
default=None, description="A list of work pool queue names to include"
)
startswith_: Optional[List[str]] = Field(
default=None,
description=(
"A list of case-insensitive starts-with matches. For example, "
" passing 'marvin' will match "
"'marvin', and 'Marvin-robot', but not 'sad-marvin'."
),
example=["marvin", "Marvin-robot"],
)

def _get_filter_list(self, db: "OrionDBInterface") -> List:
    """Build the SQLAlchemy filter criteria for this name filter."""
    name_column = db.WorkPoolQueue.name
    criteria = []
    if self.any_ is not None:
        # Exact membership match against the provided names.
        criteria.append(name_column.in_(self.any_))
    if self.startswith_ is not None:
        # Case-insensitive prefix match; matching any one prefix suffices.
        prefix_clauses = (
            name_column.ilike(f"{prefix}%") for prefix in self.startswith_
        )
        criteria.append(sa.or_(*prefix_clauses))
    return criteria


Expand Down
Loading