Skip to content

set_handlers: changed type of models_to_fetch, removed "models_download_params" #184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ All notable changes to this project will be documented in this file.
### Changed

- set_handlers: `enabled_handler`, `heartbeat_handler`, `init_handler` now can be async(Coroutines). #175 #181
- set_handlers: `models_to_fetch` and `models_download_params` united in one more flexible parameter. #184
- drop Python 3.9 support. #180
- internal code refactoring and clean-up #177

Expand Down
2 changes: 1 addition & 1 deletion docs/NextcloudTalkBotTransformers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ This library also provides an additional functionality over this endpoint for ea

@asynccontextmanager
async def lifespan(_app: FastAPI):
set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME:{}})
yield

This will automatically download models specified in ``models_to_fetch`` parameter to the application persistent storage.
Expand Down
2 changes: 1 addition & 1 deletion examples/as_app/talk_bot_ai/lib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@asynccontextmanager
async def lifespan(_app: FastAPI):
set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}})
yield


Expand Down
18 changes: 6 additions & 12 deletions nc_py_api/ex_app/integration_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,7 @@ def set_handlers(
enabled_handler: typing.Callable[[bool, AsyncNextcloudApp | NextcloudApp], typing.Awaitable[str] | str],
heartbeat_handler: typing.Callable[[], typing.Awaitable[str] | str] | None = None,
init_handler: typing.Callable[[AsyncNextcloudApp | NextcloudApp], typing.Awaitable[None] | None] | None = None,
models_to_fetch: list[str] | None = None,
models_download_params: dict | None = None,
models_to_fetch: dict[str, dict] | None = None,
map_app_static: bool = True,
):
"""Defines handlers for the application.
Expand All @@ -92,7 +91,6 @@ def set_handlers(

.. note:: ```huggingface_hub`` package should be present for automatic models fetching.

:param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
:param map_app_static: Should be folders ``js``, ``css``, ``l10n``, ``img`` automatically mounted in FastAPI or not.

.. note:: First, presence of these directories in the current working dir is checked, then one directory higher.
Expand Down Expand Up @@ -140,8 +138,7 @@ async def init_callback(
background_tasks.add_task(
__fetch_models_task,
nc,
models_to_fetch if models_to_fetch else [],
models_download_params if models_download_params else {},
models_to_fetch if models_to_fetch else {},
)
return responses.JSONResponse(content={}, status_code=200)

Expand Down Expand Up @@ -181,8 +178,7 @@ def __map_app_static_folders(fast_api_app: FastAPI):

def __fetch_models_task(
nc: NextcloudApp,
models: list[str],
params: dict[str, typing.Any],
models: dict[str, dict],
) -> None:
if models:
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
Expand All @@ -193,10 +189,8 @@ def display(self, msg=None, pos=None):
nc.set_init_status(min(int((self.n * 100 / self.total) / len(models)), 100))
return super().display(msg, pos)

if "max_workers" not in params:
params["max_workers"] = 2
if "cache_dir" not in params:
params["cache_dir"] = persistent_storage()
for model in models:
snapshot_download(model, tqdm_class=TqdmProgress, **params) # noqa
workers = models[model].pop("max_workers", 2)
cache = models[model].pop("cache_dir", persistent_storage())
snapshot_download(model, tqdm_class=TqdmProgress, **models[model], max_workers=workers, cache_dir=cache)
nc.set_init_status(100)
2 changes: 1 addition & 1 deletion tests/_install_init_handler_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

@asynccontextmanager
async def lifespan(_app: FastAPI):
ex_app.set_handlers(APP, enabled_handler, models_to_fetch=[MODEL_NAME])
ex_app.set_handlers(APP, enabled_handler, models_to_fetch={MODEL_NAME: {}})
yield


Expand Down
2 changes: 1 addition & 1 deletion tests/actual_tests/nc_app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,4 @@ async def test_set_user_same_value_async(anc_app):

def test_set_handlers_invalid_param(nc_any):
with pytest.raises(ValueError):
set_handlers(None, None, init_handler=set_handlers, models_to_fetch=["some"]) # noqa
set_handlers(None, None, init_handler=set_handlers, models_to_fetch={"some": {}}) # noqa