Skip to content

Commit

Permalink
feat: adapt fal client to proto changes in max_concurrency setting (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
chamini2 authored Dec 13, 2023
1 parent 5911ddb commit 395d211
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 52 deletions.
20 changes: 8 additions & 12 deletions projects/fal/src/fal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,7 @@ class Host(Generic[ArgsT, ReturnT]):
is executed."""

_SUPPORTED_KEYS: ClassVar[frozenset[str]] = frozenset()
_GATEWAY_KEYS: ClassVar[frozenset[str]] = frozenset(
{"serve", "exposed_port", "max_concurrency"}
)
_GATEWAY_KEYS: ClassVar[frozenset[str]] = frozenset({"serve", "exposed_port"})

def __post_init__(self):
assert not self._SUPPORTED_KEYS.intersection(
Expand Down Expand Up @@ -118,7 +116,6 @@ def register(
self,
func: Callable[ArgsT, ReturnT],
options: Options,
max_concurrency: int | None = None,
application_name: str | None = None,
application_auth_mode: Literal["public", "shared", "private"] | None = None,
metadata: dict[str, Any] | None = None,
Expand Down Expand Up @@ -311,6 +308,7 @@ class FalServerlessHost(Host):
{
"machine_type",
"keep_alive",
"max_concurrency",
"max_multiplexing",
"setup_function",
"metadata",
Expand Down Expand Up @@ -341,7 +339,6 @@ def register(
self,
func: Callable[ArgsT, ReturnT],
options: Options,
max_concurrency: int | None = None,
application_name: str | None = None,
application_auth_mode: Literal["public", "shared", "private"] | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -354,9 +351,8 @@ def register(
"machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
)
keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
max_multiplexing = options.host.get(
"max_multiplexing", FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING
)
max_concurrency = options.host.get("max_concurrency")
max_multiplexing = options.host.get("max_multiplexing")
base_image = options.host.get("_base_image", None)
scheduler = options.host.get("_scheduler", None)
scheduler_options = options.host.get("_scheduler_options", None)
Expand All @@ -370,6 +366,7 @@ def register(
scheduler=scheduler,
scheduler_options=scheduler_options,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)

partial_func = _prepare_partial_func(func)
Expand All @@ -394,7 +391,6 @@ def register(
application_name=application_name,
application_auth_mode=application_auth_mode,
machine_requirements=machine_requirements,
max_concurrency=max_concurrency,
metadata=metadata,
):
for log in partial_result.logs:
Expand All @@ -419,9 +415,8 @@ def run(
"machine_type", FAL_SERVERLESS_DEFAULT_MACHINE_TYPE
)
keep_alive = options.host.get("keep_alive", FAL_SERVERLESS_DEFAULT_KEEP_ALIVE)
max_multiplexing = options.host.get(
"max_multiplexing", FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING
)
max_concurrency = options.host.get("max_concurrency")
max_multiplexing = options.host.get("max_multiplexing")
base_image = options.host.get("_base_image", None)
scheduler = options.host.get("_scheduler", None)
scheduler_options = options.host.get("_scheduler_options", None)
Expand All @@ -436,6 +431,7 @@ def run(
scheduler=scheduler,
scheduler_options=scheduler_options,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)

return_value = _UNSET
Expand Down
76 changes: 47 additions & 29 deletions projects/fal/src/fal/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from fal.logging.isolate import IsolateLogPrinter
from fal.logging.trace import get_tracer
from fal.rest_client import REST_CLIENT
from fal.sdk import KeyScope
from fal.sdk import AliasInfo, KeyScope
from isolate.logs import Log, LogLevel, LogSource
from rich.table import Table

Expand Down Expand Up @@ -280,13 +280,11 @@ def register_application(
"Must expose port 8080 for now. This will be configurable in the future."
)

max_concurrency = gateway_options.get("max_concurrency")
id = host.register(
func=isolated_function.func,
options=isolated_function.options,
application_name=alias,
application_auth_mode=auth_mode,
max_concurrency=max_concurrency,
metadata={},
)

Expand Down Expand Up @@ -342,57 +340,77 @@ def alias_cli(ctx, host: str, port: str):
ctx.obj = api.FalServerlessClient(f"{host}:{port}")


def _alias_table(aliases: list[AliasInfo]):
table = Table(title="Function Aliases")
table.add_column("Alias")
table.add_column("Revision")
table.add_column("Auth")
table.add_column("Max Concurrency")
table.add_column("Max Multiplexing")
table.add_column("Keep Alive")

for app_alias in aliases:
table.add_row(
app_alias.alias,
app_alias.revision,
app_alias.auth_mode,
str(app_alias.max_concurrency),
str(app_alias.max_multiplexing),
str(app_alias.keep_alive),
)

return table


@alias_cli.command("list")
@click.pass_obj
def alias_list(client: api.FalServerlessClient):
with client.connect() as connection:
table = Table(title="Function Aliases")
table.add_column("Alias")
table.add_column("Revision")
table.add_column("Auth")
table.add_column("Max Concurrency")

for app_alias in connection.list_aliases():
table.add_row(
app_alias.alias,
app_alias.revision,
app_alias.auth_mode,
str(app_alias.max_concurrency),
)
aliases = connection.list_aliases()
table = _alias_table(aliases)

console.print(table)


@alias_cli.command("scale")
@click.argument("alias", required=True)
@click.argument("max_concurrency", required=True, type=int)
@click.pass_obj
def alias_scale(client: api.FalServerlessClient, alias: str, max_concurrency: int):
with client.connect() as connection:
connection.scale(application_name=alias, max_concurrency=max_concurrency)


@alias_cli.command("update")
@click.argument("alias", required=True)
@click.option("--keep-alive", type=int)
@click.option("--max-multiplexing", type=int)
@click.option("--keep-alive", "-k", type=int)
@click.option("--max-multiplexing", "-m", type=int)
@click.option("--max-concurrency", "-c", type=int)
@click.pass_obj
def alias_update(
client: api.FalServerlessClient,
alias: str,
keep_alive: int | None,
max_multiplexing: int | None,
max_concurrency: int | None,
):
with client.connect() as connection:
if not (keep_alive or max_multiplexing):
if keep_alive is None and max_multiplexing is None and max_concurrency is None:
console.log("No parameters for update were provided, ignoring.")
return

connection.update_application(
alias_info = connection.update_application(
application_name=alias,
keep_alive=keep_alive,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)
table = _alias_table([alias_info])

console.print(table)


@alias_cli.command("scale")
@click.argument("alias", required=True)
@click.argument("max_concurrency", required=True, type=int)
def alias_scale(alias: str, max_concurrency: int):
alias_update.callback(
alias=alias,
keep_alive=None,
max_multiplexing=None,
max_concurrency=max_concurrency,
) # type: ignore


##### Secrets group #####
Expand Down
26 changes: 16 additions & 10 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ class AliasInfo:
alias: str
revision: str
auth_mode: str
keep_alive: int
max_concurrency: int
max_multiplexing: int


@dataclass
Expand Down Expand Up @@ -258,7 +260,9 @@ def _from_grpc_alias_info(message: isolate_proto.AliasInfo) -> AliasInfo:
alias=message.alias,
revision=message.revision,
auth_mode=auth_mode,
keep_alive=message.keep_alive,
max_concurrency=message.max_concurrency,
max_multiplexing=message.max_multiplexing,
)


Expand Down Expand Up @@ -306,7 +310,8 @@ class MachineRequirements:
exposed_port: int | None = None
scheduler: str | None = None
scheduler_options: dict[str, Any] | None = None
max_multiplexing: int = FAL_SERVERLESS_DEFAULT_MAX_MULTIPLEXING
max_concurrency: int | None = None
max_multiplexing: int | None = None


@dataclass
Expand Down Expand Up @@ -386,7 +391,6 @@ def register(
application_name: str | None = None,
application_auth_mode: Literal["public", "private", "shared"] | None = None,
*,
max_concurrency: int | None = None,
serialization_method: str = _DEFAULT_SERIALIZATION_METHOD,
machine_requirements: MachineRequirements | None = None,
metadata: dict[str, Any] | None = None,
Expand All @@ -402,6 +406,7 @@ def register(
scheduler_options=to_struct(
machine_requirements.scheduler_options or {}
),
max_concurrency=machine_requirements.max_concurrency,
max_multiplexing=machine_requirements.max_multiplexing,
)
else:
Expand All @@ -423,7 +428,6 @@ def register(
function=wrapped_function,
environments=environments,
machine_requirements=wrapped_requirements,
max_concurrency=max_concurrency,
application_name=application_name,
auth_mode=auth_mode,
metadata=struct_metadata,
Expand All @@ -432,24 +436,25 @@ def register(
yield from_grpc(partial_result)

def scale(self, application_name: str, max_concurrency: int | None = None) -> None:
request = isolate_proto.ScaleApplicationRequest(
application_name=application_name,
max_concurrency=max_concurrency,
)
self.stub.ScaleApplication(request)
raise NotImplementedError

def update_application(
self,
application_name: str,
keep_alive: int | None = None,
max_multiplexing: int | None = None,
) -> None:
max_concurrency: int | None = None,
) -> AliasInfo:
request = isolate_proto.UpdateApplicationRequest(
application_name=application_name,
keep_alive=keep_alive,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
)
res: isolate_proto.UpdateApplicationResult = self.stub.UpdateApplication(
request
)
self.stub.UpdateApplication(request)
return from_grpc(res.alias_info)

def run(
self,
Expand All @@ -471,6 +476,7 @@ def run(
scheduler_options=to_struct(
machine_requirements.scheduler_options or {}
),
max_concurrency=machine_requirements.max_concurrency,
max_multiplexing=machine_requirements.max_multiplexing,
)
else:
Expand Down
2 changes: 1 addition & 1 deletion projects/fal/tests/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class Output(BaseModel):
keep_alive=60,
machine_type="S",
serve=True,
max_concurrency=1,
)
def addition_app(input: Input) -> Output:
print("starting...")
Expand All @@ -41,7 +42,6 @@ def addition_app(input: Input) -> Output:
app_alias = addition_app.host.register(
func=addition_app.func,
options=addition_app.options,
max_concurrency=1,
)
user_id = _get_user_id()
yield f"{user_id}-{app_alias}"
Expand Down

0 comments on commit 395d211

Please sign in to comment.