diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
index 101bbe7d0fd1..0123de97f108 100644
--- a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
+++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py
@@ -148,6 +148,7 @@
 from typing_extensions import Literal, Self
 
 import prefect
+from prefect.client.schemas.objects import Flow as APIFlow
 from prefect.exceptions import (
     InfrastructureError,
 )
@@ -172,7 +173,6 @@
 )
 
 if TYPE_CHECKING:
-    from prefect.client.schemas.objects import Flow as APIFlow
     from prefect.client.schemas.objects import FlowRun
     from prefect.client.schemas.responses import DeploymentResponse
     from prefect.flows import Flow
@@ -647,9 +647,34 @@ async def run(
     async def submit(
         self, flow: "Flow[..., R]", parameters: dict[str, Any] | None = None
     ) -> "FlowRun":
-        # TODO: Might not want to sync on every submit
-        await self.sync_with_backend()
+        """
+        EXPERIMENTAL: The interface for this method is subject to change.
+
+        Submits a flow to run in a Kubernetes job.
+
+        Args:
+            flow: The flow to submit
+            parameters: The parameters to pass to the flow
+
+        Returns:
+            A flow run object
+        """
+        if self._runs_task_group is None:
+            raise RuntimeError("Worker not properly initialized")
+        flow_run = await self._runs_task_group.start(
+            self._submit_adhoc_run, flow, parameters
+        )
+        return flow_run
 
+    async def _submit_adhoc_run(
+        self,
+        flow: "Flow[..., R]",
+        parameters: dict[str, Any] | None = None,
+        task_status: anyio.abc.TaskStatus["FlowRun"] | None = None,
+    ):
+        """
+        Submits a flow run to the Kubernetes worker.
+        """
         from prefect._experimental.bundles import (
             convert_step_to_command,
             create_bundle_for_flow_run,
@@ -660,10 +685,22 @@ async def submit(
         flow_run = await self._client.create_flow_run(
             flow, parameters=parameters, state=Pending()
         )
-        api_flow = await self._client.read_flow(flow_run.flow_id)
+        if task_status is not None:
+            task_status.started(flow_run)
+        # Avoid an API call to get the flow
+        api_flow = APIFlow(id=flow_run.flow_id, name=flow.name, labels={})
+        logger = self.get_flow_run_logger(flow_run)
 
         bundle = create_bundle_for_flow_run(flow=flow, flow_run=flow_run)
 
+        # TODO: Replace this with reading the steps from the work pool
+        upload_step = json.loads(os.environ.get("PREFECT__BUNDLE_UPLOAD_STEP", "{}"))
+        execute_step = json.loads(os.environ.get("PREFECT__BUNDLE_EXECUTE_STEP", "{}"))
+
+        upload_command = convert_step_to_command(upload_step, str(flow_run.id))
+        execute_command = convert_step_to_command(execute_step, str(flow_run.id))
+
+        logger.debug("Uploading execution bundle")
         with tempfile.TemporaryDirectory() as temp_dir:
             await (
                 anyio.Path(temp_dir)
@@ -671,33 +708,30 @@ async def submit(
                 .write_bytes(json.dumps(bundle).encode("utf-8"))
             )
 
-            # TODO: Replace this with reading the steps from the work pool
-            upload_step = json.loads(
-                os.environ.get("PREFECT__BUNDLE_UPLOAD_STEP", "{}")
-            )
-            execute_step = json.loads(
-                os.environ.get("PREFECT__BUNDLE_EXECUTE_STEP", "{}")
-            )
-
-            upload_command = convert_step_to_command(upload_step, str(flow_run.id))
-
-            await anyio.run_process(
-                upload_command + [str(flow_run.id)],
-                cwd=temp_dir,
-            )
-
-            execute_command = convert_step_to_command(execute_step, str(flow_run.id))
+            try:
+                await anyio.run_process(
+                    upload_command + [str(flow_run.id)],
+                    cwd=temp_dir,
+                )
+            except Exception as e:
+                # Only CalledProcessError carries stderr; fall back to the
+                # exception itself so the original failure is never masked.
+                stderr = getattr(e, "stderr", None)
+                if isinstance(stderr, bytes):
+                    stderr = stderr.decode("utf-8", errors="replace")
+                self._logger.error("Failed to upload bundle: %s", stderr or e)
+                raise
 
-            configuration = await self.job_configuration.from_template_and_values(
-                base_job_template=self._work_pool.base_job_template,
-                values={"command": " ".join(execute_command)},
-                client=self._client,
-            )
-            configuration.prepare_for_flow_run(flow_run=flow_run, flow=api_flow)
+            logger.debug("Successfully uploaded execution bundle")
 
-            await self.run(flow_run, configuration)
+        configuration = await self.job_configuration.from_template_and_values(
+            base_job_template=self._work_pool.base_job_template,
+            values={"command": " ".join(execute_command)},
+            client=self._client,
+        )
+        configuration.prepare_for_flow_run(flow_run=flow_run, flow=api_flow)
 
-        return flow_run
+        await self.run(flow_run, configuration)
 
     async def teardown(self, *exc_info: Any):
         await super().teardown(*exc_info)
diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py
index 5e63d7fc0947..cba74546cc77 100644
--- a/src/prefect/workers/base.py
+++ b/src/prefect/workers/base.py
@@ -566,8 +566,6 @@ async def start(
         healthcheck_thread = None
         try:
             async with self as worker:
-                # wait for an initial heartbeat to configure the worker
-                await worker.sync_with_backend()
                 # schedule the scheduled flow run polling loop
                 async with anyio.create_task_group() as loops_task_group:
                     loops_task_group.start_soon(
@@ -655,6 +653,8 @@ async def setup(self) -> None:
         await self._exit_stack.enter_async_context(self._client)
         await self._exit_stack.enter_async_context(self._runs_task_group)
 
+        await self.sync_with_backend()
+
         self.is_setup = True
 
     async def teardown(self, *exc_info: Any) -> None: