Skip to content

Commit

Permalink
fix: improve get_model_with_lora_adapters naming
Browse files Browse the repository at this point in the history
  • Loading branch information
drbh committed Jul 24, 2024
1 parent 59022c2 commit 1f3b2ae
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions server/text_generation_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"Model",
"CausalLM",
"Seq2SeqLM",
"get_model",
"get_model_with_lora_adapters",
]

FLASH_ATT_ERROR_MESSAGE = "{} requires Flash Attention enabled models."
Expand Down Expand Up @@ -304,7 +304,7 @@ class ModelType(enum.Enum):
__GLOBALS[data.name] = data.value["type"]


def _get_model(
def get_model(
model_id: str,
lora_adapter_ids: Optional[List[str]],
revision: Optional[str],
Expand Down Expand Up @@ -1124,7 +1124,7 @@ def _get_model(

# get_model wraps the internal _get_model function and adds support for loading adapters
# this provides a post model loading hook to load adapters into the model after the model has been loaded
def get_model(
def get_model_with_lora_adapters(
model_id: str,
lora_adapters: Optional[List[AdapterInfo]],
revision: Optional[str],
Expand All @@ -1137,7 +1137,7 @@ def get_model(
adapter_to_index: Dict[str, int],
):
lora_adapter_ids = [adapter.id for adapter in lora_adapters]
model = _get_model(
model = get_model(
model_id,
lora_adapter_ids,
revision,
Expand Down
4 changes: 2 additions & 2 deletions server/text_generation_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from text_generation_server.cache import Cache
from text_generation_server.interceptor import ExceptionInterceptor
from text_generation_server.models import Model, get_model
from text_generation_server.models import Model, get_model_with_lora_adapters
from text_generation_server.utils.adapter import AdapterInfo

try:
Expand Down Expand Up @@ -226,7 +226,7 @@ async def serve_inner(
server_urls = [local_url]

try:
model = get_model(
model = get_model_with_lora_adapters(
model_id,
lora_adapters,
revision,
Expand Down

0 comments on commit 1f3b2ae

Please sign in to comment.