@@ -54,7 +54,7 @@ def set_handlers(
54
54
fast_api_app : FastAPI ,
55
55
enabled_handler : typing .Callable [[bool , NextcloudApp ], str ],
56
56
heartbeat_handler : typing .Optional [typing .Callable [[], str ]] = None ,
57
- init_handler : typing .Optional [typing .Callable [[], None ]] = None ,
57
+ init_handler : typing .Optional [typing .Callable [[NextcloudApp ], None ]] = None ,
58
58
models_to_fetch : typing .Optional [list [str ]] = None ,
59
59
models_download_params : typing .Optional [dict ] = None ,
60
60
):
@@ -75,15 +75,15 @@ def set_handlers(
75
75
:param models_download_params: Parameters to pass to ``snapshot_download`` function from **huggingface_hub**.
76
76
"""
77
77
78
- def fetch_models_task (models : list [str ]) -> None :
78
+ def fetch_models_task (nc : NextcloudApp , models : list [str ]) -> None :
79
79
if models :
80
80
from huggingface_hub import snapshot_download # noqa isort:skip pylint: disable=C0415 disable=E0401
81
81
from tqdm import tqdm # noqa isort:skip pylint: disable=C0415 disable=E0401
82
82
83
83
class TqdmProgress (tqdm ):
84
84
def display (self , msg = None , pos = None ):
85
85
if init_handler is None :
86
- NextcloudApp () .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
86
+ nc .set_init_status (min (int ((self .n * 100 / self .total ) / len (models )), 100 ))
87
87
return super ().display (msg , pos )
88
88
89
89
params = models_download_params if models_download_params else {}
@@ -94,9 +94,9 @@ def display(self, msg=None, pos=None):
94
94
for model in models :
95
95
snapshot_download (model , tqdm_class = TqdmProgress , ** params ) # noqa
96
96
if init_handler is None :
97
- NextcloudApp () .set_init_status (100 )
97
+ nc .set_init_status (100 )
98
98
else :
99
- init_handler ()
99
+ init_handler (nc )
100
100
101
101
@fast_api_app .put ("/enabled" )
102
102
def enabled_callback (
@@ -114,6 +114,9 @@ def heartbeat_callback():
114
114
return responses .JSONResponse (content = {"status" : return_status }, status_code = 200 )
115
115
116
116
@fast_api_app .post ("/init" )
117
- def init_callback (background_tasks : BackgroundTasks ):
118
- background_tasks .add_task (fetch_models_task , models_to_fetch if models_to_fetch else [])
117
+ def init_callback (
118
+ background_tasks : BackgroundTasks ,
119
+ nc : typing .Annotated [NextcloudApp , Depends (nc_app )],
120
+ ):
121
+ background_tasks .add_task (fetch_models_task , nc , models_to_fetch if models_to_fetch else [])
119
122
return responses .JSONResponse (content = {}, status_code = 200 )
0 commit comments