Skip to content

Commit

Permalink
WIP: Makes agents work pool aware
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Jan 20, 2023
1 parent 84e44fc commit 84d6ab6
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 9 deletions.
44 changes: 35 additions & 9 deletions src/prefect/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
their execution.
"""
import inspect
from typing import Iterator, List, Optional, Set, Union
from typing import AsyncIterator, List, Optional, Set, Union
from uuid import UUID

import anyio
Expand All @@ -22,7 +22,8 @@
)
from prefect.infrastructure import Infrastructure, InfrastructureResult, Process
from prefect.logging import get_logger
from prefect.orion.schemas.core import BlockDocument, FlowRun, WorkQueue
from prefect.orion import schemas
from prefect.orion.schemas.core import BlockDocument, FlowRun, WorkPoolQueue, WorkQueue
from prefect.orion.schemas.filters import (
FlowRunFilter,
FlowRunFilterId,
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
)

self.work_queues: Set[str] = set(work_queues) if work_queues else set()
self.work_pool_name = work_pool_name
self.prefetch_seconds = prefetch_seconds
self.submitting_flow_run_ids = set()
self.cancelling_flow_run_ids = set()
Expand Down Expand Up @@ -100,7 +102,7 @@ async def update_matched_agent_work_queues(self):
)
self.work_queues = matched_queues

async def get_work_queues(self) -> Iterator[WorkQueue]:
async def get_work_queues(self) -> AsyncIterator[Union[WorkQueue, WorkPoolQueue]]:
"""
Loads the work queue objects corresponding to the agent's target work
queues. If any of them don't exist, they are created.
Expand All @@ -122,16 +124,32 @@ async def get_work_queues(self) -> Iterator[WorkQueue]:

for name in self.work_queues:
try:
work_queue = await self.client.read_work_queue_by_name(name)
if self.work_pool_name:
work_queue = await self.client.read_work_pool_queue(
work_pool_name=self.work_pool_name, work_pool_queue_name=name
)
else:
work_queue = await self.client.read_work_queue_by_name(name)
except ObjectNotFound:

# if the work queue wasn't found, create it
if not self.work_queue_prefix:
# do not attempt to create work queues if the agent is polling for
# queues using a regex
try:
work_queue = await self.client.create_work_queue(name=name)
self.logger.info(f"Created work queue '{name}'.")
if self.work_pool_name:
work_queue = await self.client.create_work_pool_queue(
work_pool_name=self.work_pool_name,
work_pool_queue=schemas.actions.WorkPoolQueueCreate(
name=name
),
)
self.logger.info(
f"Created work queue {name!r} in work pool {self.work_pool_name!r}."
)
else:
work_queue = await self.client.create_work_queue(name=name)
self.logger.info(f"Created work queue '{name}'.")

# if creating it raises an exception, it was probably just
# created by some other agent; rather than entering a re-read
Expand Down Expand Up @@ -171,9 +189,17 @@ async def get_and_submit_flow_runs(self) -> List[FlowRun]:

else:
try:
queue_runs = await self.client.get_runs_in_work_queue(
id=work_queue.id, limit=10, scheduled_before=before
)
if isinstance(work_queue, WorkPoolQueue):
responses = await self.client.get_scheduled_flow_runs_for_work_pool_queues(
work_pool_name=self.work_pool_name,
work_pool_queue_names=[work_queue.name],
scheduled_before=before,
)
queue_runs = [response.flow_run for response in responses]
else:
queue_runs = await self.client.get_runs_in_work_queue(
id=work_queue.id, limit=10, scheduled_before=before
)
submittable_runs.extend(queue_runs)
except ObjectNotFound:
self.logger.error(
Expand Down
7 changes: 7 additions & 0 deletions src/prefect/cli/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ async def start(
"for example `dev-` will match all work queues with a name that starts with `dev-`"
),
),
work_pool_name: str = typer.Option(
None,
"-p",
"--pool",
help="A work pool name for the agent to pull from.",
),
hide_welcome: bool = typer.Option(False, "--hide-welcome"),
api: str = SettingsOption(PREFECT_API_URL),
run_once: bool = typer.Option(
Expand Down Expand Up @@ -145,6 +151,7 @@ async def start(
async with OrionAgent(
work_queues=work_queues,
work_queue_prefix=work_queue_prefix,
work_pool_name=work_pool_name,
prefetch_seconds=prefetch_seconds,
limit=limit,
) as agent:
Expand Down
19 changes: 19 additions & 0 deletions src/prefect/client/orion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2075,6 +2075,25 @@ async def read_work_pool_queues(

return pydantic.parse_obj_as(List[WorkPoolQueue], response.json())

async def read_work_pool_queue(
self, work_pool_name: str, work_pool_queue_name: str
) -> schemas.core.WorkPoolQueue:
"""
Retrieves queues for a work pool.
Args:
work_pool_name: Name of the work pool the queue belong to.
work_pool_queue_name: Name of the work pool queue to get.
Returns:
The specified work pool queue.
"""
response = await self._client.get(
f"/experimental/work_pools/{work_pool_name}/queues/{work_pool_queue_name}"
)

return pydantic.parse_obj_as(WorkPoolQueue, response.json())

async def create_work_pool_queue(
self,
work_pool_name: str,
Expand Down

0 comments on commit 84d6ab6

Please sign in to comment.