# Fix @ in model name issue. (infiniflow#3821)
### What problem does this PR solve?

infiniflow#3814

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
KevinHuSh authored Dec 3, 2024
1 parent e66addc commit 7543047
Showing 2 changed files with 30 additions and 15 deletions.
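
Background for the change: a model ID may carry a `@<factory>` suffix (for example `deepseek-chat@DeepSeek`), but nothing stops a model name from containing `@` itself. The old call sites split on `@` unconditionally and treated everything after it as the factory, which truncated such names. A minimal sketch of the failure mode (the model IDs below are illustrative, not taken from the repository):

```python
# Illustrative model IDs; only the "<model>@<factory>" convention is real.
plain = "deepseek-chat"              # bare model name
tagged = "deepseek-chat@DeepSeek"    # model name plus factory suffix
tricky = "user@example-model"        # "@" is part of the model name itself

# Old behavior at the call sites: split on "@" and keep the first part.
for llm_id in (plain, tagged, tricky):
    print(llm_id.split("@")[0])
# deepseek-chat   (correct)
# deepseek-chat   (correct)
# user            (wrong: the model name got truncated)
```

The fix centralizes this parsing in `TenantLLMService.split_model_name_and_factory`, which only treats the suffix as a factory when it is actually listed in `conf/llm_factories.json`.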
8 changes: 2 additions & 6 deletions api/db/services/dialog_service.py
```diff
@@ -120,7 +120,7 @@ def count():
 
 
 def llm_id2llm_type(llm_id):
-    llm_id = llm_id.split("@")[0]
+    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    tmp = dialog.llm_id.split("@")
-    fid = None
-    llm_id = tmp[0]
-    if len(tmp)>1: fid = tmp[1]
-
+    llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
     llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
```
37 changes: 28 additions & 9 deletions api/db/services/llm_service.py
```diff
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import logging
+import os
+
 from api.db.services.user_service import TenantService
+from api.utils.file_utils import get_project_base_directory
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
 from api.db.db_models import DB
@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        arr = model_name.split("@")
-        if len(arr) < 2:
-            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
+        if not fid:
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
         else:
-            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
             return
         return objs[0]
@@ -61,6 +65,23 @@ def get_my_llms(cls, tenant_id):
 
         return list(objs)
 
+    @staticmethod
+    def split_model_name_and_factory(model_name):
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            return model_name, None
+        if len(arr) > 2:
+            return "@".join(arr[0:-1]), arr[-1]
+        try:
+            fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            fact = set([f["name"] for f in fact])
+            if arr[-1] not in fact:
+                return model_name, None
+            return arr[0], arr[-1]
+        except Exception as e:
+            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
+        return model_name, None
+
     @classmethod
     @DB.connection_context()
     def model_instance(cls, tenant_id, llm_type,
```
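
The new helper above is the heart of the fix. A hedged sketch of its behavior, assuming `conf/llm_factories.json` lists a factory named `DeepSeek` (the model names are made up for illustration):

```python
from api.db.services.llm_service import TenantLLMService

# Suffix matches a registered factory: split into (model, factory).
TenantLLMService.split_model_name_and_factory("deepseek-chat@DeepSeek")
# -> ("deepseek-chat", "DeepSeek")

# No "@" at all: nothing to split.
TenantLLMService.split_model_name_and_factory("deepseek-chat")
# -> ("deepseek-chat", None)

# Suffix is not a registered factory: the full name is kept intact.
TenantLLMService.split_model_name_and_factory("user@example-model")
# -> ("user@example-model", None)

# More than one "@": everything before the last "@" is the model name,
# and the suffix is taken as the factory without a registry check.
TenantLLMService.split_model_name_and_factory("team@model@DeepSeek")
# -> ("team@model", "DeepSeek")
```

Note the asymmetry: the registry lookup only guards the single-`@` case; with two or more `@`s the last segment is always treated as the factory.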
```diff
@@ -85,9 +106,7 @@ def model_instance(cls, tenant_id, llm_type,
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
-        tmp = mdlnm.split("@")
-        fid = None if len(tmp) < 2 else tmp[1]
-        mdlnm = tmp[0]
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
@@ -168,7 +187,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
         else:
             assert False, "LLM type error"
 
-        llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
+        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
 
         num = 0
         try:
@@ -179,7 +198,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
                     .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
                     .execute()
             else:
-                llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
+                if not llm_factory: llm_factory = mdlnm
                 num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
```
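
One subtlety in `increase_usage`: after the split, `llm_factory` is `None` when the ID carries no recognized factory suffix, so the create branch falls back to the raw ID, preserving the old behavior for plain names. A standalone sketch of just that fallback (not the repository's code; variable names mirror the diff):

```python
def factory_for_create(mdlnm, llm_factory):
    # Mirror of the patched create branch: when split_model_name_and_factory
    # found no registered factory, fall back to the raw model ID.
    if not llm_factory:
        llm_factory = mdlnm
    return llm_factory

assert factory_for_create("deepseek-chat", None) == "deepseek-chat"
assert factory_for_create("deepseek-chat@DeepSeek", "DeepSeek") == "DeepSeek"
```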
