diff --git a/flytekit/clis/sdk_in_container/serve.py b/flytekit/clis/sdk_in_container/serve.py index 3d0d880e78..a49f424aa9 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -1,3 +1,6 @@ +import importlib +import os +import sys from concurrent import futures import grpc @@ -51,13 +54,26 @@ def serve(ctx: click.Context): help="It will wait for the specified number of seconds before shutting down grpc server. It should only be used " "for testing.", ) +@click.option( + "--modules", + required=False, + multiple=True, + type=str, + help="List of additional files or module that defines the agent", +) @click.pass_context -def agent(_: click.Context, port, prometheus_port, worker, timeout): +def agent(_: click.Context, port, prometheus_port, worker, timeout, modules): """ Start a grpc server for the agent service. """ import asyncio + working_dir = os.getcwd() + if all(os.path.realpath(path) != working_dir for path in sys.path): + sys.path.append(working_dir) + for m in modules: + importlib.import_module(m) + asyncio.run(_start_grpc_server(port, prometheus_port, worker, timeout)) @@ -66,6 +82,7 @@ async def _start_grpc_server(port: int, prometheus_port: int, worker: int, timeo click.secho("🚀 Starting the agent service...") _start_http_server(prometheus_port) + print_agents_metadata() server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker)) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index b1f463929a..74015673bb 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -965,7 +965,7 @@ def test_attr_access_sd(): execution_id = run("attr_access_sd.py", "wf", "--uri", remote_file_path) remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) execution = remote.fetch_execution(name=execution_id) - execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15)) assert execution.error is None, f"Execution failed with error: {execution.error}" assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}" @@ -1039,7 +1039,7 @@ def retry_operation(operation): retry_operation(lambda: remote.set_input("title-input", execution.id.name, value="my report", project=PROJECT, domain=DOMAIN, python_type=str, literal_type=LiteralType(simple=SimpleType.STRING))) retry_operation(lambda: remote.approve("review-passes", execution.id.name, project=PROJECT, domain=DOMAIN)) - remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15)) assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]} with pytest.raises(FlyteAssertion, match="Outputs could not be found because the execution ended in failure"): @@ -1048,7 +1048,7 @@ def retry_operation(operation): retry_operation(lambda: remote.set_input("title-input", execution.id.name, value="my report", project=PROJECT, domain=DOMAIN, python_type=str, literal_type=LiteralType(simple=SimpleType.STRING))) retry_operation(lambda: remote.reject("review-passes", execution.id.name, project=PROJECT, domain=DOMAIN)) - remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5)) + remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15)) assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]}