Skip to content

Commit

Permalink
feat(fal): automatically lower max/min_concurrency for everyone other…
Browse files Browse the repository at this point in the history
… than falbot
  • Loading branch information
efiop committed Sep 11, 2024
1 parent fa0a4c5 commit 8739f86
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 6 deletions.
4 changes: 4 additions & 0 deletions projects/fal/src/fal/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,5 +119,9 @@ def access_token(self) -> str:
def bearer_token(self) -> str:
return "Bearer " + self.access_token

@property
def is_falbot(self) -> bool:
return self.info["sub"] == "github|110602490"


USER = UserAccess()
55 changes: 49 additions & 6 deletions projects/fal/src/fal/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,41 @@ def __post_init__(self):
raise ValueError("No machine type provided.")


# NOTE: These values need to be in-sync with the values in the serverless API.
MAX_REGISTERED_MIN_CONCURRENCY = 0
MAX_REGISTERED_MAX_CONCURRENCY = 2


def _get_min_concurrency(min_concurrency) -> int:
if USER.is_falbot:
return min_concurrency

if min_concurrency > MAX_REGISTERED_MIN_CONCURRENCY:
logger.warning(
"min_concurrency must be less than or equal to "
f"{MAX_REGISTERED_MIN_CONCURRENCY} for regular users. "
f"Setting min_concurrency to {MAX_REGISTERED_MIN_CONCURRENCY}."
)
return MAX_REGISTERED_MIN_CONCURRENCY

return min_concurrency


def _get_max_concurrency(max_concurrency) -> int:
if USER.is_falbot:
return max_concurrency

if max_concurrency > MAX_REGISTERED_MAX_CONCURRENCY:
logger.warning(
"max_concurrency must be less than or equal to "
f"{MAX_REGISTERED_MAX_CONCURRENCY} for regular users. "
f"Setting max_concurrency to {MAX_REGISTERED_MAX_CONCURRENCY}."
)
return MAX_REGISTERED_MAX_CONCURRENCY

return max_concurrency


@dataclass
class FalServerlessConnection:
hostname: str
Expand Down Expand Up @@ -511,8 +546,12 @@ def register(
scheduler_options=to_struct(
machine_requirements.scheduler_options or {}
),
max_concurrency=machine_requirements.max_concurrency,
min_concurrency=machine_requirements.min_concurrency,
max_concurrency=_get_max_concurrency(
machine_requirements.max_concurrency
),
min_concurrency=_get_min_concurrency(
machine_requirements.min_concurrency
),
max_multiplexing=machine_requirements.max_multiplexing,
)
else:
Expand Down Expand Up @@ -561,8 +600,8 @@ def update_application(
application_name=application_name,
keep_alive=keep_alive,
max_multiplexing=max_multiplexing,
max_concurrency=max_concurrency,
min_concurrency=min_concurrency,
max_concurrency=_get_max_concurrency(max_concurrency),
min_concurrency=_get_min_concurrency(min_concurrency),
)
res: isolate_proto.UpdateApplicationResult = self.stub.UpdateApplication(
request
Expand Down Expand Up @@ -604,9 +643,13 @@ def run(
scheduler_options=to_struct(
machine_requirements.scheduler_options or {}
),
max_concurrency=machine_requirements.max_concurrency,
max_concurrency=_get_max_concurrency(
machine_requirements.max_concurrency
),
max_multiplexing=machine_requirements.max_multiplexing,
min_concurrency=machine_requirements.min_concurrency,
min_concurrency=_get_min_concurrency(
machine_requirements.min_concurrency
),
)
else:
wrapped_requirements = None
Expand Down

0 comments on commit 8739f86

Please sign in to comment.