Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adapt fal client to proto changes in max_concurrency setting #6

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 8 additions & 12 deletions projects/fal/src/fal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ class Host(Generic[ArgsT, ReturnT]):
is executed."""

_SUPPORTED_KEYS: ClassVar[frozenset[str]] = frozenset()
_GATEWAY_KEYS: ClassVar[frozenset[str]] = frozenset(
{"serve", "exposed_port", "max_concurrency"}
)
_GATEWAY_KEYS: ClassVar[frozenset[str]] = frozenset({"serve", "exposed_port"})

def __post_init__(self):
assert not self._SUPPORTED_KEYS.intersection(
Expand Down Expand Up @@ -118,7 +116,6 @@ def register(
self,
func: Callable[ArgsT, ReturnT],
options: Options,
max_concurrency: int | None = None,
application_name: str | None = None,
application_auth_mode: Literal["public", "shared", "private"] | None = None,
metadata: dict[str, Any] | None = None,
Expand Down Expand Up @@ -311,6 +308,7 @@ class FalServerlessHost(Host):
{
"machine_type",
"keep_alive",
"max_concurrency",
"max_multiplexing",
"setup_function",
"metadata",
Expand Down Expand Up @@ -341,7 +339,6 @@ def register(
self,
func: Callable[ArgsT, ReturnT],
options: Options,
max_concurrency: int | None = None,
application_name: str | None = None,
application_auth_mode: Literal["public", "shared", "private"] | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -354,9 +351,8 @@ def register(
"machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
)
keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
max_multiplexing = options.host.get(
"max_multiplexing", FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING
)
max_concurrency = options.host.get("max_concurrency")
max_multiplexing = options.host.get("max_multiplexing")
base_image = options.host.get("_base_image", None)
scheduler = options.host.get("_scheduler", None)
scheduler_options = options.host.get("_scheduler_options", None)
Expand All @@ -370,6 +366,7 @@ def register(
scheduler=scheduler,
scheduler_options=scheduler_options,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)

partial_func = _prepare_partial_func(func)
Expand All @@ -394,7 +391,6 @@ def register(
application_name=application_name,
application_auth_mode=application_auth_mode,
machine_requirements=machine_requirements,
max_concurrency=max_concurrency,
metadata=metadata,
):
for log in partial_result.logs:
Expand All @@ -419,9 +415,8 @@ def run(
"machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
)
keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
max_multiplexing = options.host.get(
"max_multiplexing", FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING
)
max_concurrency = options.host.get("max_concurrency")
max_multiplexing = options.host.get("max_multiplexing")
base_image = options.host.get("_base_image", None)
scheduler = options.host.get("_scheduler", None)
scheduler_options = options.host.get("_scheduler_options", None)
Expand All @@ -436,6 +431,7 @@ def run(
scheduler=scheduler,
scheduler_options=scheduler_options,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)

return_value = _UNSET
Expand Down
76 changes: 47 additions & 29 deletions projects/fal/src/fal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fal.logging.isolate import IsolateLogPrinter
from fal.logging.trace import get_tracer
from fal.rest_client import REST_CLIENT
from fal.sdk import KeyScope
from fal.sdk import AliasInfo, KeyScope
from isolate.logs import Log, LogLevel, LogSource
from rich.table import Table

Expand Down Expand Up @@ -280,13 +280,11 @@ def register_application(
"Must expose port 8080 for now. This will be configurable in the future."
)

max_concurrency = gateway_options.get("max_concurrency")
id = host.register(
func=isolated_function.func,
options=isolated_function.options,
application_name=alias,
application_auth_mode=auth_mode,
max_concurrency=max_concurrency,
metadata={},
)

Expand Down Expand Up @@ -342,57 +340,77 @@ def alias_cli(ctx, host: str, port: str):
ctx.obj = api.FalServerlessClient(f"{host}:{port}")


def _alias_table(aliases: list[AliasInfo]):
table = Table(title="Function Aliases")
table.add_column("Alias")
table.add_column("Revision")
table.add_column("Auth")
table.add_column("Max Concurrency")
table.add_column("Max Multiplexing")
table.add_column("Keep Alive")

for app_alias in aliases:
table.add_row(
app_alias.alias,
app_alias.revision,
app_alias.auth_mode,
str(app_alias.max_concurrency),
str(app_alias.max_multiplexing),
str(app_alias.keep_alive),
)

return table


@alias_cli.command("list")
@click.pass_obj
def alias_list(client: api.FalServerlessClient):
with client.connect() as connection:
table = Table(title="Function Aliases")
table.add_column("Alias")
table.add_column("Revision")
table.add_column("Auth")
table.add_column("Max Concurrency")

for app_alias in connection.list_aliases():
table.add_row(
app_alias.alias,
app_alias.revision,
app_alias.auth_mode,
str(app_alias.max_concurrency),
)
aliases = connection.list_aliases()
table = _alias_table(aliases)

console.print(table)


@alias_cli.command("scale")
@click.argument("alias", required=True)
@click.argument("max_concurrency", required=True, type=int)
@click.pass_obj
def alias_scale(client: api.FalServerlessClient, alias: str, max_concurrency: int):
with client.connect() as connection:
connection.scale(application_name=alias, max_concurrency=max_concurrency)


@alias_cli.command("update")
@click.argument("alias", required=True)
@click.option("--keep-alive", type=int)
@click.option("--max-multiplexing", type=int)
@click.option("--keep-alive", "-k", type=int)
@click.option("--max-multiplexing", "-m", type=int)
@click.option("--max-concurrency", "-c", type=int)
@click.pass_obj
def alias_update(
client: api.FalServerlessClient,
alias: str,
keep_alive: int | None,
max_multiplexing: int | None,
max_concurrency: int | None,
):
with client.connect() as connection:
if not (keep_alive or max_multiplexing):
if keep_alive is None and max_multiplexing is None and max_concurrency is None:
console.log("No parameters for update were provided, ignoring.")
return

connection.update_application(
alias_info = connection.update_application(
application_name=alias,
keep_alive=keep_alive,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)
table = _alias_table([alias_info])

console.print(table)


@alias_cli.command("scale")
@click.argument("alias", required=True)
@click.argument("max_concurrency", required=True, type=int)
def alias_scale(alias: str, max_concurrency: int):
alias_update.callback(
alias=alias,
keep_alive=None,
max_multiplexing=None,
max_concurrency=max_concurrency,
) # type: ignore


##### Secrets group #####
Expand Down
26 changes: 16 additions & 10 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ class AliasInfo:
alias: str
revision: str
auth_mode: str
keep_alive: int
max_concurrency: int
max_multiplexing: int


@dataclass
Expand Down Expand Up @@ -258,7 +260,9 @@ def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
alias=message.alias,
revision=message.revision,
auth_mode=auth_mode,
keep_alive=message.keep_alive,
max_concurrency=message.max_concurrency,
max_multiplexing=message.max_multiplexing,
)


Expand Down Expand Up @@ -306,7 +310,8 @@ class MachineRequirements:
exposed_port: int | None = None
scheduler: str | None = None
scheduler_options: dict[str, Any] | None = None
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING
max_concurrency: int | None = None
max_multiplexing: int | None = None


@dataclass
Expand Down Expand Up @@ -386,7 +391,6 @@ def register(
application_name: str | None = None,
application_auth_mode: Literal["public", "private", "shared"] | None = None,
*,
max_concurrency: int | None = None,
serialization_method: str = _DEFAULT_SERIALIZATION_METHOD,
machine_requirements: MachineRequirements | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -402,6 +406,7 @@ def register(
scheduler_options=to_struct(
machine_requirements.scheduler_options or {}
),
max_concurrency=machine_requirements.max_concurrency,
max_multiplexing=machine_requirements.max_multiplexing,
)
else:
Expand All @@ -423,7 +428,6 @@ def register(
function=wrapped_function,
environments=environments,
machine_requirements=wrapped_requirements,
max_concurrency=max_concurrency,
application_name=application_name,
auth_mode=auth_mode,
metadata=struct_metadata,
Expand All @@ -432,24 +436,25 @@ def register(
yield from_grpc(partial_result)

def scale(self, application_name: str, max_concurrency: int | None = None) -> None:
request = isolate_proto.ScaleApplicationRequest(
application_name=application_name,
max_concurrency=max_concurrency,
)
self.stub.ScaleApplication(request)
raise NotImplementedError

def update_application(
self,
application_name: str,
keep_alive: int | None = None,
max_multiplexing: int | None = None,
) -> None:
max_concurrency: int | None = None,
) -> AliasInfo:
request = isolate_proto.UpdateApplicationRequest(
application_name=application_name,
keep_alive=keep_alive,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)
res: isolate_proto.UpdateApplicationResult = self.stub.UpdateApplication(
request
)
self.stub.UpdateApplication(request)
return from_grpc(res.alias_info)

def run(
self,
Expand All @@ -471,6 +476,7 @@ def run(
scheduler_options=to_struct(
machine_requirements.scheduler_options or {}
),
max_concurrency=machine_requirements.max_concurrency,
max_multiplexing=machine_requirements.max_multiplexing,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Output(BaseModel):
keep_alive=60,
machine_type="S",
serve=True,
max_concurrency=1,
)
def addition_app(input: Input) -> Output:
print("starting...")
Expand All @@ -41,7 +42,6 @@ def addition_app(input: Input) -> Output:
app_alias = addition_app.host.register(
func=addition_app.func,
options=addition_app.options,
max_concurrency=1,
)
user_id = _get_user_id()
yield f"{user_id}-{app_alias}"
Expand Down