Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add module loading support to pyflyte serve #3147

Merged
merged 6 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion flytekit/clis/sdk_in_container/serve.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import importlib
import os
import sys
from concurrent import futures

import grpc
Expand Down Expand Up @@ -51,13 +54,26 @@ def serve(ctx: click.Context):
help="It will wait for the specified number of seconds before shutting down grpc server. It should only be used "
"for testing.",
)
@click.option(
"--modules",
required=False,
multiple=True,
type=str,
help="List of additional files or module that defines the agent",
)
@click.pass_context
def agent(_: click.Context, port, prometheus_port, worker, timeout):
def agent(_: click.Context, port, prometheus_port, worker, timeout, modules):
"""
Start a grpc server for the agent service.
"""
import asyncio

working_dir = os.getcwd()
if all(os.path.realpath(path) != working_dir for path in sys.path):
sys.path.append(working_dir)
for m in modules:
importlib.import_module(m)

asyncio.run(_start_grpc_server(port, prometheus_port, worker, timeout))


Expand All @@ -66,6 +82,7 @@ async def _start_grpc_server(port: int, prometheus_port: int, worker: int, timeo

click.secho("🚀 Starting the agent service...")
_start_http_server(prometheus_port)

print_agents_metadata()

server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker))
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ def test_attr_access_sd():
execution_id = run("attr_access_sd.py", "wf", "--uri", remote_file_path)
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.fetch_execution(name=execution_id)
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
execution = remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15))
assert execution.error is None, f"Execution failed with error: {execution.error}"
assert execution.closure.phase == WorkflowExecutionPhase.SUCCEEDED, f"Execution failed with phase: {execution.closure.phase}"

Expand Down Expand Up @@ -1039,7 +1039,7 @@ def retry_operation(operation):
retry_operation(lambda: remote.set_input("title-input", execution.id.name, value="my report", project=PROJECT, domain=DOMAIN, python_type=str, literal_type=LiteralType(simple=SimpleType.STRING)))
retry_operation(lambda: remote.approve("review-passes", execution.id.name, project=PROJECT, domain=DOMAIN))

remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15))
assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]}

with pytest.raises(FlyteAssertion, match="Outputs could not be found because the execution ended in failure"):
Expand All @@ -1048,7 +1048,7 @@ def retry_operation(operation):
retry_operation(lambda: remote.set_input("title-input", execution.id.name, value="my report", project=PROJECT, domain=DOMAIN, python_type=str, literal_type=LiteralType(simple=SimpleType.STRING)))
retry_operation(lambda: remote.reject("review-passes", execution.id.name, project=PROJECT, domain=DOMAIN))

remote.wait(execution=execution, timeout=datetime.timedelta(minutes=5))
remote.wait(execution=execution, timeout=datetime.timedelta(minutes=15))
assert execution.outputs["o0"] == {"title": "my report", "data": [1.0, 2.0, 3.0, 4.0, 5.0]}


Expand Down
Loading