From d2c3c937e942614aa2a8130b327267757d54580f Mon Sep 17 00:00:00 2001 From: guanguangpeng Date: Wed, 9 Oct 2024 14:07:17 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20XInference=20?= =?UTF-8?q?=E4=BE=9B=E5=BA=94=E5=95=86=E5=8A=A0=E8=BD=BD=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E6=97=B6=E7=9A=84=E6=9C=AA=E8=AE=A4=E8=AF=81=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=20API=20=E5=9F=9F=E5=90=8D=E6=97=A0=E6=95=88=E7=9A=84=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/xinference_model_provider/credential/llm.py | 2 +- .../xinference_model_provider/xinference_model_provider.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py index 8a6ad4958fb..dc01c790658 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -34,7 +34,7 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') try: - model_list = provider.get_base_model_list(model_credential.get('api_base'), model_type) + model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'), model_type) except Exception as e: raise AppApiException(ValidCode.valid_error.value, "API 域名无效") exist = provider.get_model_info_by_name(model_list, model_name) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py index 208a8a6f820..0fbc3cc32fd 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py @@ -371,10 +371,13 @@ def get_model_provide_info(self): 'xinference_icon_svg'))) @staticmethod - def get_base_model_list(api_base, model_type): + def get_base_model_list(api_base, api_key, model_type): base_url = get_base_url(api_base) base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') - r = requests.request(method="GET", url=f"{base_url}/models", timeout=5) + headers = {} + if api_key: + headers['Authorization'] = f"Bearer {api_key}" + r = requests.request(method="GET", url=f"{base_url}/models", headers=headers, timeout=5) r.raise_for_status() model_list = r.json().get('data') return [model for model in model_list if model.get('model_type') == model_type]