Skip to content

Commit 7f597b6

Browse files
committed
feat: support ollama rerank
--story=1017862 --user=王孝刚 希望支持在 Ollama 中添加 rerank 模型 issue#2243 https://www.tapd.cn/57709429/s/1655139
1 parent a071d7c commit 7f597b6

File tree

4 files changed

+169
-1
lines changed

4 files changed

+169
-1
lines changed

apps/application/swagger_api/chat_api.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,15 @@ def get_request_params_api():
165165
openapi.Parameter(name='min_trample', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False,
166166
description=_("Minimum number of clicks")),
167167
openapi.Parameter(name='comparer', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False,
168-
description=_("or|and comparator"))
168+
description=_("or|and comparator")),
169+
openapi.Parameter(name='start_time', in_=openapi.IN_QUERY,
170+
type=openapi.TYPE_STRING,
171+
required=True,
172+
description=_('start time')),
173+
openapi.Parameter(name='end_time', in_=openapi.IN_QUERY,
174+
type=openapi.TYPE_STRING,
175+
required=True,
176+
description=_('End time')),
169177
]
170178

171179

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: reranker.py
6+
@date:2024/7/12 15:10
7+
@desc: Credential form and validation for Ollama rerank models
8+
"""
9+
from typing import Dict
10+
11+
from django.utils.translation import gettext as _
12+
13+
from common import forms
14+
from common.exception.app_exception import AppApiException
15+
from common.forms import BaseForm
16+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
17+
from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker
18+
from langchain_core.documents import BaseDocumentCompressor, Document
19+
from django.utils.translation import gettext_lazy as _, gettext
20+
21+
22+
class OllamaReRankModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for rerank models served by Ollama."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Check that the credential can actually drive a rerank model.

        The checks run in order: the model type must be ``'RERANKER'`` and be
        listed by ``provider.get_model_type_list()``; the model list must be
        retrievable from ``api_base``; ``model_name`` must appear in that list
        (with or without the ``:latest`` suffix); finally a one-document probe
        rerank is executed.

        :param model_type: model type identifier, must equal ``'RERANKER'``
        :param model_name: name of the model to validate
        :param model_credential: credential values; ``api_base`` is read here
        :param model_params: not used by this validator
        :param provider: provider object supplying ``get_model_type_list``,
            ``get_base_model_list`` and ``get_model``
        :param raise_exception: when True a failed probe raises instead of
            returning False
        :return: True when the probe succeeds, False when it fails and
            ``raise_exception`` is False
        :raises AppApiException: for an unsupported model type, an unreachable
            ``api_base``, a model that is not installed, or (only when
            ``raise_exception`` is True) a failed probe
        """
        if model_type != 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value,
                                  gettext('{model_type} Model type is not supported').format(model_type=model_type))
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value,
                                  _('{model_type} Model type is not supported').format(model_type=model_type))
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception as e:
            # Keep the original failure attached so the root cause stays visible in tracebacks.
            raise AppApiException(ValidCode.valid_error.value, _('API domain name is invalid')) from e

        def _is_requested_model(entry) -> bool:
            # Entries without a string 'model' value must not crash the lookup.
            candidate = entry.get('model')
            if candidate == model_name:
                return True
            return isinstance(candidate, str) and candidate.replace(":latest", "") == model_name

        available_models = model_list.get('models') or []
        if not any(_is_requested_model(entry) for entry in available_models):
            # NOTE(review): the other raises in this method pass `ValidCode.<member>.value`;
            # this one passes the member itself. Left unchanged - confirm which form the
            # exception handler expects.
            raise AppApiException(ValidCode.model_not_fount,
                                  _('The model does not exist, please download the model first'))

        try:
            model: OllamaReranker = provider.get_model(model_type, model_name, model_credential)
            probe_result = model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
            # The reranker may report a failed request as an empty list instead of raising,
            # so an empty answer to a one-document probe is treated as a failed verification.
            if not probe_result:
                raise ValueError(gettext('The rerank service returned no results'))
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      gettext(
                                          'Verification failed, please check whether the parameters are correct: {error}').format(
                                          error=str(e)))
            else:
                return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return ``model_info`` unchanged; no field is masked here."""
        return model_info

    def build_model(self, model_info: Dict[str, object]):
        """Verify that ``model_info`` carries the required keys and return ``self``.

        :raises AppApiException: with code 500 when a required key is missing
        """
        for key in ['model']:
            if key not in model_info:
                raise AppApiException(500, _('{key} is required').format(key=key))
        return self

    # Form fields presented for this credential.
    api_base = forms.TextInputField('API URL', required=True)
    api_key = forms.TextInputField('API Key', required=True)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import Sequence, Optional, Any, Dict
2+
from langchain_core.callbacks import Callbacks
3+
from langchain_core.documents import BaseDocumentCompressor, Document
4+
import requests
5+
6+
from setting.models_provider.base_model_provider import MaxKBBaseModel
7+
8+
9+
class OllamaReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that reranks documents via an HTTP ``/v1/rerank`` endpoint."""

    # Explicit ``None`` defaults keep the no-argument ``super().__init__()`` call below valid
    # even where un-defaulted Optional fields are treated as required (e.g. pydantic v2).
    # URL of the Ollama server
    api_base: Optional[str] = None
    # The model name to use for reranking
    model_name: Optional[str] = None
    # Optional bearer token; no Authorization header is sent when it is empty
    api_key: Optional[str] = None
    # Maximum number of results requested from the rerank endpoint
    top_n: Optional[int] = 3

    @staticmethod
    def new_instance(model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a reranker from a stored credential.

        ``api_base`` and ``api_key`` are read from ``model_credential``;
        ``top_n`` is taken from ``model_kwargs`` and defaults to 3.
        """
        # NOTE(review): confirm that the provider invokes this factory with
        # (model_name, model_credential) in exactly this positional order.
        return OllamaReranker(api_base=model_credential.get('api_base'), model_name=model_name,
                              api_key=model_credential.get('api_key'), top_n=model_kwargs.get('top_n', 3))

    def __init__(
            self, api_base: Optional[str] = None, model_name: Optional[str] = None, top_n=3,
            api_key: Optional[str] = None
    ):
        """Store the connection settings.

        :raises ValueError: when ``api_base`` or ``model_name`` is None
        """
        super().__init__()

        if api_base is None:
            raise ValueError("Please provide server URL")

        if model_name is None:
            raise ValueError("Please provide the model name")

        self.api_base = api_base
        self.model_name = model_name
        self.api_key = api_key
        self.top_n = top_n

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Rerank ``documents`` against ``query`` through the rerank endpoint.

        :param documents: candidate documents; only ``page_content`` is sent
        :param query: text the documents are ranked against
        :param callbacks: accepted for interface compatibility, not used
        :return: new ``Document`` objects built from the response ``results``,
            each carrying ``relevance_score`` in its metadata; an empty list
            when there is nothing to rank or the HTTP request fails
        :raises ValueError: when the response body has no ``results`` key
        """
        import logging

        if not documents:
            return []

        # Only authenticate when a key is configured; otherwise the literal
        # string "Bearer None" would be sent to the server.
        headers = {}
        if self.api_key:
            headers['Authorization'] = f'Bearer {self.api_key}'

        payload = {
            'model': self.model_name,
            'query': query,
            'documents': [document.page_content for document in documents],
            'top_n': self.top_n
        }

        # Tolerate a trailing slash in the configured base URL.
        endpoint = f"{self.api_base.rstrip('/')}/v1/rerank"

        try:
            # (connect, read) timeout so a stalled server cannot block the caller forever.
            response = requests.post(endpoint, headers=headers, json=payload, timeout=(10, 120))
            response.raise_for_status()
            res = response.json()

            if 'results' not in res:
                raise ValueError("The API response did not contain rerank results.")

            # NOTE(review): each result is expected to expose 'text' and
            # 'relevance_score'; verify against the server's actual response schema.
            return [
                Document(page_content=d['text'], metadata={'relevance_score': d['relevance_score']})
                for d in res['results']
            ]

        except requests.exceptions.RequestException as e:
            # Best effort: a failed request yields no documents instead of an exception.
            logging.getLogger(__name__).error("Error during API request: %s", e)
            return []

apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
1919
from setting.models_provider.impl.ollama_model_provider.credential.image import OllamaImageModelCredential
2020
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
21+
from setting.models_provider.impl.ollama_model_provider.credential.reranker import OllamaReRankModelCredential
2122
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
2223
from setting.models_provider.impl.ollama_model_provider.model.image import OllamaImage
2324
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
25+
from setting.models_provider.impl.ollama_model_provider.model.reranker import OllamaReranker
2426
from smartdoc.conf import PROJECT_DIR
2527
from django.utils.translation import gettext as _
2628

@@ -153,12 +155,19 @@
153155
]
154156
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
155157
ollama_image_model_credential = OllamaImageModelCredential()
158+
ollama_reranker_model_credential = OllamaReRankModelCredential()
156159
embedding_model_info = [
157160
ModelInfo(
158161
'nomic-embed-text',
159162
_('A high-performance open embedding model with a large token context window.'),
160163
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
161164
]
165+
reranker_model_info = [
166+
ModelInfo(
167+
'ollama:reranker',
168+
'',
169+
ModelTypeConst.RERANKER, ollama_reranker_model_credential, OllamaReranker),
170+
]
162171

163172
image_model_info = [
164173
ModelInfo(
@@ -189,6 +198,8 @@
189198
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), )
190199
.append_model_info_list(image_model_info)
191200
.append_default_model_info(image_model_info[0])
201+
.append_model_info_list(reranker_model_info)
202+
.append_default_model_info(reranker_model_info[0])
192203
.build()
193204
)
194205

0 commit comments

Comments
 (0)