
Commit

isthaison committed Dec 3, 2024
2 parents 5b8809d + 44f5e22 commit 6e7d3f6
Showing 8 changed files with 162 additions and 21 deletions.

8 changes: 2 additions & 6 deletions api/db/services/dialog_service.py

@@ -120,7 +120,7 @@ def count():
 
 
 def llm_id2llm_type(llm_id):
-    llm_id = llm_id.split("@")[0]
+    llm_id, _ = TenantLLMService.split_model_name_and_factory(llm_id)
     fnm = os.path.join(get_project_base_directory(), "conf")
     llm_factories = json.load(open(os.path.join(fnm, "llm_factories.json"), "r"))
     for llm_factory in llm_factories["factory_llm_infos"]:
@@ -132,11 +132,7 @@ def llm_id2llm_type(llm_id):
 def chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     st = timer()
-    tmp = dialog.llm_id.split("@")
-    fid = None
-    llm_id = tmp[0]
-    if len(tmp)>1: fid = tmp[1]
-
+    llm_id, fid = TenantLLMService.split_model_name_and_factory(dialog.llm_id)
     llm = LLMService.query(llm_name=llm_id) if not fid else LLMService.query(llm_name=llm_id, fid=fid)
     if not llm:
         llm = TenantLLMService.query(tenant_id=dialog.tenant_id, llm_name=llm_id) if not fid else \
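
Both hunks replace ad-hoc split("@") parsing with the shared helper added to TenantLLMService below. The old code kept only the text before the first "@", which mangles model ids whose names themselves contain "@". A rough illustration of the difference (the model id is made up for the example):

llm_id = "user/model@v2@OpenAI-API-Compatible"  # hypothetical id: the name itself contains "@"

# Old parsing: cut at the first "@" and lose the rest of the name.
old_name = llm_id.split("@")[0]        # -> "user/model"  (name truncated)

# New parsing (see split_model_name_and_factory): cut at the LAST "@",
# treating the suffix as a factory only when it names a known provider.
name, factory = llm_id.rsplit("@", 1)  # -> ("user/model@v2", "OpenAI-API-Compatible")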

37 changes: 28 additions & 9 deletions api/db/services/llm_service.py

@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import json
 import logging
+import os
+
 from api.db.services.user_service import TenantService
+from api.utils.file_utils import get_project_base_directory
 from rag.llm import EmbeddingModel, CvModel, ChatModel, RerankModel, Seq2txtModel, TTSModel
 from api.db import LLMType
 from api.db.db_models import DB
@@ -36,11 +40,11 @@ class TenantLLMService(CommonService):
     @classmethod
     @DB.connection_context()
     def get_api_key(cls, tenant_id, model_name):
-        arr = model_name.split("@")
-        if len(arr) < 2:
-            objs = cls.query(tenant_id=tenant_id, llm_name=model_name)
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(model_name)
+        if not fid:
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm)
         else:
-            objs = cls.query(tenant_id=tenant_id, llm_name=arr[0], llm_factory=arr[1])
+            objs = cls.query(tenant_id=tenant_id, llm_name=mdlnm, llm_factory=fid)
         if not objs:
             return
         return objs[0]
@@ -61,6 +65,23 @@ def get_my_llms(cls, tenant_id):
 
         return list(objs)
 
+    @staticmethod
+    def split_model_name_and_factory(model_name):
+        arr = model_name.split("@")
+        if len(arr) < 2:
+            return model_name, None
+        if len(arr) > 2:
+            return "@".join(arr[0:-1]), arr[-1]
+        try:
+            fact = json.load(open(os.path.join(get_project_base_directory(), "conf/llm_factories.json"), "r"))["factory_llm_infos"]
+            fact = set([f["name"] for f in fact])
+            if arr[-1] not in fact:
+                return model_name, None
+            return arr[0], arr[-1]
+        except Exception as e:
+            logging.exception(f"TenantLLMService.split_model_name_and_factory got exception: {e}")
+            return model_name, None
+
     @classmethod
     @DB.connection_context()
     def model_instance(cls, tenant_id, llm_type,
@@ -85,9 +106,7 @@ def model_instance(cls, tenant_id, llm_type,
             assert False, "LLM type error"
 
         model_config = cls.get_api_key(tenant_id, mdlnm)
-        tmp = mdlnm.split("@")
-        fid = None if len(tmp) < 2 else tmp[1]
-        mdlnm = tmp[0]
+        mdlnm, fid = TenantLLMService.split_model_name_and_factory(mdlnm)
         if model_config: model_config = model_config.to_dict()
         if not model_config:
             if llm_type in [LLMType.EMBEDDING, LLMType.RERANK]:
@@ -168,7 +187,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
         else:
            assert False, "LLM type error"
 
-        llm_name = mdlnm.split("@")[0] if "@" in mdlnm else mdlnm
+        llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(mdlnm)
 
         num = 0
         try:
@@ -179,7 +198,7 @@ def increase_usage(cls, tenant_id, llm_type, used_tokens, llm_name=None):
                     .where(cls.model.tenant_id == tenant_id, cls.model.llm_factory == tenant_llm.llm_factory, cls.model.llm_name == llm_name)\
                     .execute()
             else:
-                llm_factory = mdlnm.split("@")[1] if "@" in mdlnm else mdlnm
+                if not llm_factory: llm_factory = mdlnm
                 num = cls.model.create(tenant_id=tenant_id, llm_factory=llm_factory, llm_name=llm_name, used_tokens=used_tokens)
         except Exception:
             logging.exception("TenantLLMService.increase_usage got exception")
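
Taken together, the new helper defines one contract for model ids of the form name[@factory]. A self-contained sketch of that contract — the factory set here is an illustrative stand-in for conf/llm_factories.json:

KNOWN_FACTORIES = {"OpenAI", "Tongyi-Qianwen", "OpenAI-API-Compatible"}  # illustrative subset

def split_model_name_and_factory(model_name: str):
    arr = model_name.split("@")
    if len(arr) < 2:
        return model_name, None             # no "@" at all: no factory suffix
    if len(arr) > 2:
        return "@".join(arr[:-1]), arr[-1]  # several "@": last segment is the factory
    if arr[-1] not in KNOWN_FACTORIES:
        return model_name, None             # suffix is not a known factory: part of the name
    return arr[0], arr[-1]

assert split_model_name_and_factory("gpt-4o@OpenAI") == ("gpt-4o", "OpenAI")
assert split_model_name_and_factory("text-embedding-v2") == ("text-embedding-v2", None)
assert split_model_name_and_factory("user@model@OpenAI") == ("user@model", "OpenAI")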

8 changes: 6 additions & 2 deletions deepdoc/parser/pdf_parser.py

@@ -956,8 +956,12 @@ def __images__(self, fnm, zoomin=3, page_from=0,
                                 enumerate(self.pdf.pages[page_from:page_to])]
             self.page_images_x2 = [p.to_image(resolution=72 * zoomin * 2).annotated for i, p in
                                    enumerate(self.pdf.pages[page_from:page_to])]
-            self.page_chars = [[{**c, 'top': c['top'], 'bottom': c['bottom']} for c in page.dedupe_chars().chars if self._has_color(c)] for page in
-                               self.pdf.pages[page_from:page_to]]
+            try:
+                self.page_chars = [[{**c, 'top': c['top'], 'bottom': c['bottom']} for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
+            except Exception as e:
+                logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
+                self.page_chars = [[] for _ in range(page_to - page_from)]  # If failed to extract, using empty list instead.
+
             self.total_page = len(self.pdf.pages)
         except Exception:
             logging.exception("RAGFlowPdfParser __images__")
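
Note that the fallback builds one empty list per page rather than a single empty list, so downstream code that indexes page_chars[i] next to page_images[i] stays aligned; pages with no extractable characters then flow through the same path as scanned pages. The pattern, distilled into a standalone sketch (names are illustrative, not part of the parser's API):

def chars_per_page(pages, extract):
    # Extract characters for each page, degrading to empty lists on failure (sketch).
    try:
        return [extract(page) for page in pages]
    except Exception:
        # Keep one entry per page so per-page indexes still line up downstream.
        return [[] for _ in pages]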

3 changes: 1 addition & 2 deletions rag/llm/embedding_model.py

@@ -680,11 +680,10 @@ def encode(self, texts: list, batch_size=16):
         return np.array(res.embeddings), res.total_tokens
 
     def encode_queries(self, text):
-        res = self.client.embed
         res = self.client.embed(
             texts=text, model=self.model_name, input_type="query"
         )
-        return np.array(res.embeddings), res.total_tokens
+        return np.array(res.embeddings)[0], res.total_tokens
 
 
 class HuggingFaceEmbed(Base):
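
Two fixes here: a leftover line that bound self.client.embed without calling it is removed, and the single query embedding is now returned as a 1-D vector instead of a 1×d matrix. The shape difference, in plain numpy:

import numpy as np

embeddings = [[0.1, 0.2, 0.3]]    # the response holds a batch with one embedding

batch = np.array(embeddings)      # shape (1, 3): still a one-row matrix
vector = np.array(embeddings)[0]  # shape (3,): the query vector callers expect

assert batch.shape == (1, 3) and vector.shape == (3,)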

46 changes: 46 additions & 0 deletions sdk/python/test/test_frontend_api/common.py

@@ -5,6 +5,7 @@
 
 DATASET_NAME_LIMIT = 128
 
+
 def create_dataset(auth, dataset_name):
     authorization = {"Authorization": auth}
     url = f"{HOST_ADDRESS}/v1/kb/create"
@@ -27,8 +28,53 @@ def rm_dataset(auth, dataset_id):
     res = requests.post(url=url, headers=authorization, json=json)
     return res.json()
 
+
 def update_dataset(auth, json_req):
     authorization = {"Authorization": auth}
     url = f"{HOST_ADDRESS}/v1/kb/update"
     res = requests.post(url=url, headers=authorization, json=json_req)
     return res.json()
+
+
+def upload_file(auth, dataset_id, path):
+    authorization = {"Authorization": auth}
+    url = f"{HOST_ADDRESS}/v1/document/upload"
+    base_name = os.path.basename(path)
+    json_req = {
+        "kb_id": dataset_id,
+    }
+
+    file = {
+        'file': open(f'{path}', 'rb')
+    }
+
+    res = requests.post(url=url, headers=authorization, files=file, data=json_req)
+    return res.json()
+
+def list_document(auth, dataset_id):
+    authorization = {"Authorization": auth}
+    url = f"{HOST_ADDRESS}/v1/document/list?kb_id={dataset_id}"
+    res = requests.get(url=url, headers=authorization)
+    return res.json()
+
+def get_docs_info(auth, doc_ids):
+    authorization = {"Authorization": auth}
+    json_req = {
+        "doc_ids": doc_ids
+    }
+    url = f"{HOST_ADDRESS}/v1/document/infos"
+    res = requests.post(url=url, headers=authorization, json=json_req)
+    return res.json()
+
+def parse_docs(auth, doc_ids):
+    authorization = {"Authorization": auth}
+    json_req = {
+        "doc_ids": doc_ids,
+        "run": 1
+    }
+    url = f"{HOST_ADDRESS}/v1/document/run"
+    res = requests.post(url=url, headers=authorization, json=json_req)
+    return res.json()
+
+def parse_file(auth, document_id):
+    pass
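
One caveat in upload_file as committed: the handle from open(path, 'rb') is never closed, and base_name is computed but unused. A tidier sketch of the same request (relies on the same requests import and HOST_ADDRESS constant as the helpers above; the name upload_file_v2 is hypothetical):

def upload_file_v2(auth, dataset_id, path):
    # Same endpoint and payload as upload_file above, but the handle is closed on exit.
    url = f"{HOST_ADDRESS}/v1/document/upload"
    with open(path, "rb") as fp:
        res = requests.post(url, headers={"Authorization": auth},
                            files={"file": fp}, data={"kb_id": dataset_id})
    return res.json()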

76 changes: 76 additions & 0 deletions sdk/python/test/test_frontend_api/test_chunk.py

@@ -0,0 +1,76 @@
+#
+# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from common import HOST_ADDRESS, create_dataset, list_dataset, rm_dataset, update_dataset, upload_file, DATASET_NAME_LIMIT
+from common import list_document, get_docs_info, parse_docs
+from time import sleep
+from timeit import default_timer as timer
+import re
+import pytest
+import random
+import string
+
+
+def test_parse_txt_document(get_auth):
+    # create dataset
+    res = create_dataset(get_auth, "test_parse_txt_document")
+    assert res.get("code") == 0, f"{res.get('message')}"
+
+    # list dataset
+    page_number = 1
+    dataset_list = []
+    dataset_id = None
+    while True:
+        res = list_dataset(get_auth, page_number)
+        data = res.get("data").get("kbs")
+        for item in data:
+            dataset_id = item.get("id")
+            dataset_list.append(dataset_id)
+        if len(dataset_list) < page_number * 150:
+            break
+        page_number += 1
+
+    filename = 'ragflow_test.txt'
+    res = upload_file(get_auth, dataset_id, f"../test_sdk_api/test_data/{filename}")
+    assert res.get("code") == 0, f"{res.get('message')}"
+
+    res = list_document(get_auth, dataset_id)
+
+    doc_id_list = []
+    for doc in res['data']['docs']:
+        doc_id_list.append(doc['id'])
+
+    res = get_docs_info(get_auth, doc_id_list)
+    print(doc_id_list)
+    doc_count = len(doc_id_list)
+    res = parse_docs(get_auth, doc_id_list)
+
+    start_ts = timer()
+    while True:
+        res = get_docs_info(get_auth, doc_id_list)
+        finished_count = 0
+        for doc_info in res['data']:
+            if doc_info['progress'] == 1:
+                finished_count += 1
+        if finished_count == doc_count:
+            break
+        sleep(1)
+    print('time cost {:.1f}s'.format(timer() - start_ts))
+
+    # delete dataset
+    for dataset_id in dataset_list:
+        res = rm_dataset(get_auth, dataset_id)
+        assert res.get("code") == 0, f"{res.get('message')}"
+    print(f"{len(dataset_list)} datasets are deleted")

4 changes: 2 additions & 2 deletions web/src/components/page-rank.tsx

@@ -9,7 +9,7 @@ const PageRank = () => {
     <Flex gap={20} align="center">
       <Flex flex={1}>
         <Form.Item
-          name={['parser_config', 'pagerank']}
+          name={['pagerank']}
           noStyle
           initialValue={0}
           rules={[{ required: true }]}
@@ -18,7 +18,7 @@ const PageRank = () => {
         </Form.Item>
       </Flex>
       <Form.Item
-        name={['parser_config', 'pagerank']}
+        name={['pagerank']}
         noStyle
         initialValue={0}
         rules={[{ required: true }]}

1 change: 1 addition & 0 deletions

@@ -59,6 +59,7 @@ export const useFetchKnowledgeConfigurationOnMount = (form: FormInstance) => {
       'parser_id',
       'language',
       'parser_config',
+      'pagerank',
     ]),
     avatar: fileList,
   });
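
The two web-side hunks are one change: pagerank moves from parser_config.pagerank to a top-level form field. The PageRank component now writes the new path, and the knowledge-configuration hook adds 'pagerank' to the keys it restores into the form on mount, so saved values round-trip under the new shape.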
