Skip to content

Commit

Permalink
Run submission via the worker's task group
Browse files Browse the repository at this point in the history
  • Loading branch information
desertaxle committed Feb 21, 2025
1 parent cea7de8 commit 22fb194
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 30 deletions.
87 changes: 59 additions & 28 deletions src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@
from typing_extensions import Literal, Self

import prefect
from prefect.client.schemas.objects import Flow as APIFlow
from prefect.exceptions import (
InfrastructureError,
)
Expand All @@ -172,7 +173,6 @@
)

if TYPE_CHECKING:
from prefect.client.schemas.objects import Flow as APIFlow
from prefect.client.schemas.objects import FlowRun
from prefect.client.schemas.responses import DeploymentResponse
from prefect.flows import Flow
Expand Down Expand Up @@ -647,9 +647,34 @@ async def run(
async def submit(
    self, flow: "Flow[..., R]", parameters: dict[str, Any] | None = None
) -> "FlowRun":
    """
    EXPERIMENTAL: The interface for this method is subject to change.

    Submits a flow to run in a Kubernetes job.

    The actual submission work is delegated to ``_submit_adhoc_run``, which is
    started on the worker's runs task group so the job outlives this call; this
    method only waits for the task to report the created flow run via
    ``task_status.started(...)``.

    Args:
        flow: The flow to submit.
        parameters: The parameters to pass to the flow.

    Returns:
        A flow run object.

    Raises:
        RuntimeError: If the worker's runs task group has not been created
            (i.e. the worker was not set up via its async context manager).
    """
    # The task group is created during worker setup; a missing group means
    # submit() was called on a worker that was never properly initialized.
    if self._runs_task_group is None:
        raise RuntimeError("Worker not properly initialized")
    # TaskGroup.start() waits until _submit_adhoc_run calls
    # task_status.started(flow_run), then returns that flow run while the
    # submission continues in the background.
    flow_run = await self._runs_task_group.start(
        self._submit_adhoc_run, flow, parameters
    )
    return flow_run

async def _submit_adhoc_run(
self,
flow: "Flow[..., R]",
parameters: dict[str, Any] | None = None,
task_status: anyio.abc.TaskStatus["FlowRun"] | None = None,
):
"""
Submits a flow run to the Kubernetes worker.
"""
from prefect._experimental.bundles import (
convert_step_to_command,
create_bundle_for_flow_run,
Expand All @@ -660,44 +685,50 @@ async def submit(
flow_run = await self._client.create_flow_run(
flow, parameters=parameters, state=Pending()
)
api_flow = await self._client.read_flow(flow_run.flow_id)
if task_status is not None:
task_status.started(flow_run)
# Avoid an API call to get the flow
api_flow = APIFlow(id=flow_run.flow_id, name=flow.name, labels={})
logger = self.get_flow_run_logger(flow_run)

bundle = create_bundle_for_flow_run(flow=flow, flow_run=flow_run)

# TODO: Replace this with reading the steps from the work pool
upload_step = json.loads(os.environ.get("PREFECT__BUNDLE_UPLOAD_STEP", "{}"))
execute_step = json.loads(os.environ.get("PREFECT__BUNDLE_EXECUTE_STEP", "{}"))

upload_command = convert_step_to_command(upload_step, str(flow_run.id))
execute_command = convert_step_to_command(execute_step, str(flow_run.id))

logger.debug("Uploading execution bundle")
with tempfile.TemporaryDirectory() as temp_dir:
await (
anyio.Path(temp_dir)
.joinpath(str(flow_run.id))
.write_bytes(json.dumps(bundle).encode("utf-8"))
)

# TODO: Replace this with reading the steps from the work pool
upload_step = json.loads(
os.environ.get("PREFECT__BUNDLE_UPLOAD_STEP", "{}")
)
execute_step = json.loads(
os.environ.get("PREFECT__BUNDLE_EXECUTE_STEP", "{}")
)

upload_command = convert_step_to_command(upload_step, str(flow_run.id))

await anyio.run_process(
upload_command + [str(flow_run.id)],
cwd=temp_dir,
)

execute_command = convert_step_to_command(execute_step, str(flow_run.id))
try:
await anyio.run_process(
upload_command + [str(flow_run.id)],
cwd=temp_dir,
)
except Exception as e:
self._logger.error(
"Failed to upload bundle: %s", e.stderr.decode("utf-8")
)
raise e

configuration = await self.job_configuration.from_template_and_values(
base_job_template=self._work_pool.base_job_template,
values={"command": " ".join(execute_command)},
client=self._client,
)
configuration.prepare_for_flow_run(flow_run=flow_run, flow=api_flow)
logger.debug("Successfully uploaded execution bundle")

await self.run(flow_run, configuration)
configuration = await self.job_configuration.from_template_and_values(
base_job_template=self._work_pool.base_job_template,
values={"command": " ".join(execute_command)},
client=self._client,
)
configuration.prepare_for_flow_run(flow_run=flow_run, flow=api_flow)

return flow_run
await self.run(flow_run, configuration)

async def teardown(self, *exc_info: Any):
await super().teardown(*exc_info)
Expand Down
4 changes: 2 additions & 2 deletions src/prefect/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,6 @@ async def start(
healthcheck_thread = None
try:
async with self as worker:
# wait for an initial heartbeat to configure the worker
await worker.sync_with_backend()
# schedule the scheduled flow run polling loop
async with anyio.create_task_group() as loops_task_group:
loops_task_group.start_soon(
Expand Down Expand Up @@ -655,6 +653,8 @@ async def setup(self) -> None:
await self._exit_stack.enter_async_context(self._client)
await self._exit_stack.enter_async_context(self._runs_task_group)

await self.sync_with_backend()

self.is_setup = True

async def teardown(self, *exc_info: Any) -> None:
Expand Down

0 comments on commit 22fb194

Please sign in to comment.