@@ -65,12 +65,13 @@ def __init__(
         super().__init__()
 
         self.base_model_paths = base_model_paths
+
         self.max_model_len = model_config.max_model_len
         self.engine_client = engine_client
         self.model_config = model_config
 
         self.static_lora_modules = lora_modules
-        self.lora_requests: list[LoRARequest] = []
+        self.lora_requests: dict[str, LoRARequest] = {}
         self.lora_id_counter = AtomicCounter(0)
 
         self.lora_resolvers: list[LoRAResolver] = []
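
The per-name locking added in `load_lora_adapter` and `unload_lora_adapter` below relies on `self.lora_resolver_lock[lora_name]`, whose definition is not part of this hunk. A minimal sketch of how such a per-adapter lock container can behave, assuming a `defaultdict` of `asyncio.Lock` keyed by adapter name (the actual attribute may be declared elsewhere in `__init__`):

```python
from asyncio import Lock
from collections import defaultdict

# Hypothetical stand-in for `self.lora_resolver_lock`; the real definition
# is not shown in this diff.
lora_resolver_lock: defaultdict[str, Lock] = defaultdict(Lock)

# Indexing by adapter name lazily creates one Lock per name, so coroutines
# touching the same adapter serialize while different adapters proceed
# independently.
assert lora_resolver_lock["sql-lora"] is lora_resolver_lock["sql-lora"]
assert lora_resolver_lock["sql-lora"] is not lora_resolver_lock["math-lora"]
```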
@@ -138,7 +139,7 @@ async def show_available_models(self) -> ModelList:
                       parent=lora.base_model_name if lora.base_model_name else
                       self.base_model_paths[0].name,
                       permission=[ModelPermission()])
-            for lora in self.lora_requests
+            for lora in self.lora_requests.values()
         ]
         prompt_adapter_cards = [
             ModelCard(id=prompt_adapter.prompt_adapter_name,
@@ -155,53 +156,60 @@ async def load_lora_adapter(
             request: LoadLoRAAdapterRequest,
             base_model_name: Optional[str] = None
     ) -> Union[ErrorResponse, str]:
-        error_check_ret = await self._check_load_lora_adapter_request(request)
-        if error_check_ret is not None:
-            return error_check_ret
-
-        lora_name, lora_path = request.lora_name, request.lora_path
-        unique_id = self.lora_id_counter.inc(1)
-        lora_request = LoRARequest(lora_name=lora_name,
-                                   lora_int_id=unique_id,
-                                   lora_path=lora_path)
-        if base_model_name is not None and self.is_base_model(base_model_name):
-            lora_request.base_model_name = base_model_name
-
-        # Validate that the adapter can be loaded into the engine
-        # This will also pre-load it for incoming requests
-        try:
-            await self.engine_client.add_lora(lora_request)
-        except BaseException as e:
-            error_type = "BadRequestError"
-            status_code = HTTPStatus.BAD_REQUEST
-            if "No adapter found" in str(e):
-                error_type = "NotFoundError"
-                status_code = HTTPStatus.NOT_FOUND
-
-            return create_error_response(message=str(e),
-                                         err_type=error_type,
-                                         status_code=status_code)
-
-        self.lora_requests.append(lora_request)
-        logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
-                    lora_path)
-        return f"Success: LoRA adapter '{lora_name}' added successfully."
+        lora_name = request.lora_name
+
+        # Ensure atomicity based on the lora name
+        async with self.lora_resolver_lock[lora_name]:
+            error_check_ret = await self._check_load_lora_adapter_request(
+                request)
+            if error_check_ret is not None:
+                return error_check_ret
+
+            lora_path = request.lora_path
+            unique_id = self.lora_id_counter.inc(1)
+            lora_request = LoRARequest(lora_name=lora_name,
+                                       lora_int_id=unique_id,
+                                       lora_path=lora_path)
+            if base_model_name is not None and self.is_base_model(
+                    base_model_name):
+                lora_request.base_model_name = base_model_name
+
+            # Validate that the adapter can be loaded into the engine
+            # This will also pre-load it for incoming requests
+            try:
+                await self.engine_client.add_lora(lora_request)
+            except Exception as e:
+                error_type = "BadRequestError"
+                status_code = HTTPStatus.BAD_REQUEST
+                if "No adapter found" in str(e):
+                    error_type = "NotFoundError"
+                    status_code = HTTPStatus.NOT_FOUND
+
+                return create_error_response(message=str(e),
+                                             err_type=error_type,
+                                             status_code=status_code)
+
+            self.lora_requests[lora_name] = lora_request
+            logger.info("Loaded new LoRA adapter: name '%s', path '%s'",
+                        lora_name, lora_path)
+            return f"Success: LoRA adapter '{lora_name}' added successfully."
 
     async def unload_lora_adapter(
             self,
             request: UnloadLoRAAdapterRequest) -> Union[ErrorResponse, str]:
-        error_check_ret = await self._check_unload_lora_adapter_request(request
-                                                                        )
-        if error_check_ret is not None:
-            return error_check_ret
-
         lora_name = request.lora_name
-        self.lora_requests = [
-            lora_request for lora_request in self.lora_requests
-            if lora_request.lora_name != lora_name
-        ]
-        logger.info("Removed LoRA adapter: name '%s'", lora_name)
-        return f"Success: LoRA adapter '{lora_name}' removed successfully."
+
+        # Ensure atomicity based on the lora name
+        async with self.lora_resolver_lock[lora_name]:
+            error_check_ret = await self._check_unload_lora_adapter_request(
+                request)
+            if error_check_ret is not None:
+                return error_check_ret
+
+            # Safe to delete now since we hold the lock
+            del self.lora_requests[lora_name]
+            logger.info("Removed LoRA adapter: name '%s'", lora_name)
+            return f"Success: LoRA adapter '{lora_name}' removed successfully."
 
     async def _check_load_lora_adapter_request(
             self, request: LoadLoRAAdapterRequest) -> Optional[ErrorResponse]:
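
The reason the existence checks now run inside `async with self.lora_resolver_lock[lora_name]` is that two concurrent load (or load and unload) requests for the same adapter name could otherwise both pass the check before either one mutates `self.lora_requests`. A standalone sketch of the pattern with hypothetical names (`locks`, `loaded`, `load`), not the vLLM API:

```python
import asyncio
from collections import defaultdict

locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
loaded: dict[str, int] = {}  # stands in for self.lora_requests


async def load(name: str) -> str:
    # Check-then-act is atomic per name: both steps run under the same lock.
    async with locks[name]:
        if name in loaded:
            return f"'{name}' has already been loaded."
        await asyncio.sleep(0)  # stands in for engine_client.add_lora(...)
        loaded[name] = len(loaded) + 1
        return f"'{name}' added."


async def main() -> None:
    # Two racing loads of the same adapter: exactly one succeeds,
    # the other sees the duplicate and returns the error message.
    print(await asyncio.gather(load("sql-lora"), load("sql-lora")))


asyncio.run(main())
```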
@@ -213,8 +221,7 @@ async def _check_load_lora_adapter_request(
                 status_code=HTTPStatus.BAD_REQUEST)
 
         # Check if the lora adapter with the given name already exists
-        if any(lora_request.lora_name == request.lora_name
-               for lora_request in self.lora_requests):
+        if request.lora_name in self.lora_requests:
             return create_error_response(
                 message=
                 f"The lora adapter '{request.lora_name}' has already been "
@@ -227,17 +234,16 @@ async def _check_load_lora_adapter_request(
     async def _check_unload_lora_adapter_request(
             self,
             request: UnloadLoRAAdapterRequest) -> Optional[ErrorResponse]:
-        # Check if either 'lora_name' or 'lora_int_id' is provided
-        if not request.lora_name and not request.lora_int_id:
+        # If 'lora_name' is not provided, return an error
+        if not request.lora_name:
             return create_error_response(
                 message=
-                "either 'lora_name' and 'lora_int_id' needs to be provided.",
+                "'lora_name' needs to be provided to unload a LoRA adapter.",
                 err_type="InvalidUserInput",
                 status_code=HTTPStatus.BAD_REQUEST)
 
         # Check if the lora adapter with the given name exists
-        if not any(lora_request.lora_name == request.lora_name
-                   for lora_request in self.lora_requests):
+        if request.lora_name not in self.lora_requests:
             return create_error_response(
                 message=
                 f"The lora adapter '{request.lora_name}' cannot be found.",
@@ -260,9 +266,8 @@ async def resolve_lora(
         """
         async with self.lora_resolver_lock[lora_name]:
             # First check if this LoRA is already loaded
-            for existing in self.lora_requests:
-                if existing.lora_name == lora_name:
-                    return existing
+            if lora_name in self.lora_requests:
+                return self.lora_requests[lora_name]
 
             base_model_name = self.model_config.model
             unique_id = self.lora_id_counter.inc(1)
@@ -279,7 +284,7 @@ async def resolve_lora(
 
                     try:
                         await self.engine_client.add_lora(lora_request)
-                        self.lora_requests.append(lora_request)
+                        self.lora_requests[lora_name] = lora_request
                         logger.info(
                             "Resolved and loaded LoRA adapter '%s' using %s",
                             lora_name, resolver.__class__.__name__)
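
Across the diff, `self.lora_requests` changes from a `list[LoRARequest]` to a `dict[str, LoRARequest]` keyed by adapter name, so duplicate detection, lookup, and removal become single dict operations instead of linear scans and list rebuilds. A small illustration of those access patterns, using a hypothetical `FakeLoRARequest` stand-in rather than the real `LoRARequest` class:

```python
from dataclasses import dataclass


@dataclass
class FakeLoRARequest:  # hypothetical stand-in for vLLM's LoRARequest
    lora_name: str
    lora_int_id: int
    lora_path: str


lora_requests: dict[str, FakeLoRARequest] = {}

req = FakeLoRARequest("sql-lora", 1, "/adapters/sql-lora")
lora_requests[req.lora_name] = req

# Existence check, lookup, and removal are all O(1) name-keyed operations.
assert "sql-lora" in lora_requests
assert lora_requests["sql-lora"] is req
del lora_requests["sql-lora"]
assert "sql-lora" not in lora_requests
```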