Skip to content

feat: 文档状态70% #1674

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion apps/common/db/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from django.db.models import QuerySet

from common.db.compiler import AppSQLCompiler
from common.db.sql_execute import select_one, select_list
from common.db.sql_execute import select_one, select_list, update_execute
from common.response.result import Page


Expand Down Expand Up @@ -109,6 +109,24 @@ def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
return select_list(exec_sql, exec_params)


def native_update(queryset: QuerySet | Dict[str, QuerySet], select_string: str,
                  field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None,
                  with_table_name=False):
    """
    Execute a native UPDATE statement whose condition clause is generated from a queryset.

    Mirrors `native_search`, but routes the generated SQL through `update_execute`
    instead of `select_list`.

    :param with_table_name: whether the generated SQL qualifies fields with the table name
    :param queryset: condition builder -- a single QuerySet, or a dict of named QuerySets
                     (the dict form is handled by generate_sql_by_query_dict)
    :param select_string: SQL prefix (the UPDATE ... SET ... part); WHERE/LIMIT etc.
                          are generated from the queryset and appended
    :param field_replace_dict: fields to replace when generating the SQL
    :return: result of update_execute -- presumably the affected-row count; confirm
    """
    if isinstance(queryset, Dict):
        exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name)
    else:
        exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name)
    return update_execute(exec_sql, exec_params)


def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler):
"""
分页查询
Expand Down
151 changes: 122 additions & 29 deletions apps/common/event/listener_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,29 @@
import datetime
import logging
import os
import threading
import traceback
from typing import List

import django.db.models
from django.db import models
from django.db.models import QuerySet
from django.db.models.functions import Substr, Reverse
from langchain_core.embeddings import Embeddings

from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
from common.event.common import embedding_poxy
from common.db.search import native_search, get_dynamics_model, native_update
from common.db.sql_execute import sql_execute, update_execute
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from common.util.page_utils import page
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR

max_kb_error = logging.getLogger(__file__)
max_kb = logging.getLogger(__file__)
lock = threading.Lock()


class SyncWebDatasetArgs:
Expand Down Expand Up @@ -114,7 +119,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
@param embedding_model: 向量模型
"""
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
status = Status.success
# 更新到开始状态
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED)
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
Expand All @@ -125,23 +131,114 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)

def is_save_function():
return QuerySet(Paragraph).filter(id=paragraph_id).exists()
def is_the_task_interrupted():
_paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first()
if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE:
return True
return False

# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted)
# 更新到开始状态
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
State.SUCCESS)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING,
State.FAILURE)
finally:
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status})
max_kb.info(f'结束--->向量化段落:{paragraph_id}')

@staticmethod
def embedding_by_data_list(data_list: List, embedding_model: Embeddings):
    """Embed a pre-built list of paragraph records in a single batch."""
    vector_store = VectorStore.get_embedding_vector()
    # NOTE(review): batch_save's third argument is used elsewhere as an
    # "is the task interrupted" hook; passing `lambda: True` reads as
    # "always interrupted" -- confirm against batch_save's semantics.
    vector_store.batch_save(data_list, embedding_model, lambda: True)

@staticmethod
def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None):
    """
    Build a page handler that embeds each paragraph record of a page.

    The returned callable embeds paragraphs one by one, stopping early as soon
    as `is_the_task_interrupted()` reports True. `post_apply` runs after the
    page is processed, whether or not the loop was interrupted.
    """

    def _handle_page(paragraph_records):
        for record in paragraph_records:
            # Re-check interruption before every single paragraph.
            if is_the_task_interrupted():
                break
            ListenerManagement.embedding_by_paragraph(str(record.get('id')), embedding_model)
        post_apply()

    return _handle_page

@staticmethod
def get_aggregation_document_status(document_id):
    """
    Build a callable that recomputes the aggregated status meta for one document.

    Qualifies fields with the table name (with_table_name=True), unlike the
    dataset/queryset variants.
    """

    def aggregation_document_status():
        sql_file = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        update_sql = get_file_content(sql_file)
        native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id),
                       'default_sql': QuerySet(Document).filter(id=document_id)},
                      update_sql, with_table_name=True)

    return aggregation_document_status

@staticmethod
def get_aggregation_document_status_by_dataset_id(dataset_id):
    """
    Build a callable that recomputes the status meta for every document of a dataset.

    NOTE(review): unlike get_aggregation_document_status, this variant does not
    pass with_table_name=True -- confirm whether the difference is intentional.
    """

    def aggregation_document_status():
        sql_file = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        update_sql = get_file_content(sql_file)
        native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id),
                       'default_sql': QuerySet(Document).filter(dataset_id=dataset_id)},
                      update_sql)

    return aggregation_document_status

@staticmethod
def get_aggregation_document_status_by_query_set(queryset):
    """
    Build a callable that recomputes the status meta for the documents matched
    by an arbitrary queryset.
    """

    def aggregation_document_status():
        sql_file = os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')
        native_update({'document_custom_sql': queryset, 'default_sql': queryset},
                      get_file_content(sql_file))

    return aggregation_document_status

@staticmethod
def post_update_document_status(document_id, task_type: TaskType):
    """
    Finalise a document's status for `task_type` once its paragraph work is
    done, then propagate the terminal state to still-revoking paragraphs.
    """
    _document = QuerySet(Document).filter(id=document_id).first()

    # NOTE(review): assumes the document still exists; if it was deleted this
    # raises AttributeError on `_document.status` -- confirm callers guarantee it.
    status = Status(_document.status)
    # A task that was asked to stop (REVOKE) ends as REVOKED; otherwise SUCCESS.
    if status[task_type] == State.REVOKE:
        status[task_type] = State.REVOKED
    else:
        status[task_type] = State.SUCCESS
    # status_meta['aggs'] appears to carry per-status paragraph counts; any
    # failed paragraphs downgrade the whole document to FAILURE.
    for item in _document.status_meta.get('aggs', []):
        agg_status = item.get('status')
        agg_count = item.get('count')
        if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0:
            status[task_type] = State.FAILURE
    ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type])

    # Flip this document's paragraphs whose status slot for `task_type` still
    # reads REVOKE over to REVOKED. The status string is reversed so the slot
    # can be addressed by position with Substr.
    # NOTE(review): Substr's third argument (the length) is task_type.value
    # rather than 1 -- looks like it should extract a single character; confirm
    # against the status-string layout used by update_paragraph_status.sql.
    ListenerManagement.update_status(QuerySet(Paragraph).annotate(
        reversed_status=Reverse('status'),
        task_type_status=Substr('reversed_status', task_type.value,
                                task_type.value),
    ).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'),
        task_type,
        State.REVOKED)

@staticmethod
def update_status(query_set: QuerySet, taskType: TaskType, state: State):
    """
    Rewrite one task-type slot of the packed `status` string for every row
    matched by `query_set`.

    The SQL template (update_paragraph_status.sql) is parameterised by plain
    string substitution; the placeholders locate which character of the status
    string to rewrite and which table to update.

    :param query_set: rows to update; its model supplies the target table name
    :param taskType: which task slot (1-based position) of the status string to set
    :param state: the state value written into that slot
    """
    exec_sql = get_file_content(
        os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql'))
    bit_number = len(TaskType)
    up_index = taskType.value - 1
    next_index = taskType.value + 1
    current_index = taskType.value
    status_number = state.value
    params_dict = {'${bit_number}': bit_number, '${up_index}': up_index,
                   '${status_number}': status_number, '${next_index}': next_index,
                   '${table_name}': query_set.model._meta.db_table, '${current_index}': current_index}
    # Plain text substitution, not SQL parameters: every value is an int or the
    # model-derived table name, so no untrusted input reaches the statement.
    for placeholder, value in params_dict.items():
        exec_sql = exec_sql.replace(placeholder, str(value))
    # `with` replaces the manual acquire/try/finally/release and guarantees the
    # lock is released even if native_update raises.
    with lock:
        native_update(query_set, exec_sql)

@staticmethod
def embedding_by_document(document_id, embedding_model: Embeddings):
"""
Expand All @@ -153,33 +250,29 @@ def embedding_by_document(document_id, embedding_model: Embeddings):
if not try_lock('embedding' + str(document_id)):
return
max_kb.info(f"开始--->向量化文档:{document_id}")
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
status = Status.success
# 批量修改状态为PADDING
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED)
try:
data_list = native_search(
{'problem': QuerySet(
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter(
**{'paragraph.document_id': document_id}),
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
# 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id)

def is_save_function():
return QuerySet(Document).filter(id=document_id).exists()

# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function)
def is_the_task_interrupted():
document = QuerySet(Document).filter(id=document_id).first()
if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE:
return True
return False

# 根据段落进行向量化处理
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5,
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted,
ListenerManagement.get_aggregation_document_status(
document_id)),
is_the_task_interrupted)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
finally:
# 修改状态
QuerySet(Document).filter(id=document_id).update(
**{'status': status, 'update_time': datetime.datetime.now()})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING)
ListenerManagement.get_aggregation_document_status(document_id)()
max_kb.info(f"结束--->向量化文档:{document_id}")
un_lock('embedding' + str(document_id))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上述代码存在以下问题:

  1. QuerySet 应该是 django.db.models.Model 类,并不是 Django ORM 的 QuerySelector 这个东西。
  2. 对于一些操作(比如更新状态,批量保存等),应该使用更合适的函数来代替原生函数或参数列表;如可以将 UpdateStatus, BatchSaveDataList 等方法封装成专用的方法。
  3. 在嵌入时,对于没有执行成功的任务,不能直接设置新的状态,应先查询是否有重复的记录。

以下是根据上述建议进行的一些改进:

import typing as t


def embed_paragraph_by_paragraph(paragraph_id: int, embedding_model: Embeddings) -> None:

这样就解决了前面提到的所有问题并提供了更好的模板和结构。在实际编写程序过程中,请按照具体实现环境及逻辑来调整这些细节。

最后,由于涉及较多的实际业务需求,上面提供的代码只是一个非常基础和简洁的设计示例,可能需要进一步的业务理解和技术开发知识才能正确地运行和应用这个示例。

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段代码是关于自然语言处理(NLP)中使用Django ORM和SQLite数据库查询的一个简单示例。

以下是几点需要改善的地方:

  1. 嵌套函数:

    • 使用 @staticmethod 和返回值作为外部函数的参数,提高效率。
    • 确保所有的变量都被妥善地声明,并且没有遗漏。
  2. 文档更新策略:

    • 在处理时确保每个逻辑块都有合适的结束标签 {},以防止意外更改输出语句。
from typing import List

import django.db.models
from django.db import models
from django.db.models.functions import Concat, Coalesce, StrStr
from langchain_core.embeddings import Embeddings

from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model, native_update, sql_execute
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR


class MaxKBError(Exception): pass
  
def embedding_by_paragraph(paragraph: dict):
    """向量化段落
    
    Args:
        paragraph (dict): 段落字典数据
        
    Returns:
        None : 运行结果存储在paragraph中
    """

@try_lock('embedding_' + str(paragraph['id']))
def run_embedding_by_paragraph(context=None, error_callback=un_lock):
  1. API接口设计: 应该更清晰地说明各个方法的目的、输入/输出格式以及使用的参数类型。

  2. 单元测试与性能测试:

  • 建立一个简单的断言来检查嵌入是否有效或出错。

我推荐你对整个代码做深入理解并根据实际项目情况进行调整和改进。希望这些建议有所帮助!

Expand Down
27 changes: 27 additions & 0 deletions apps/common/util/page_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: MaxKB
@Author:虎
@file: page_utils.py
@date:2024/11/21 10:32
@desc:
"""
from math import ceil


def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False):
    """
    Walk `query_set` in fixed-size slices and feed each slice to `handler`.

    @param query_set: object supporting count() and slicing (e.g. a Django QuerySet)
    @param page_size: number of records fetched per slice
    @param handler: callable invoked once per slice of records
    @param is_the_task_interrupted: checked before each slice; returning True
           stops the walk immediately, skipping the remaining slices
    @return: None
    """
    total = query_set.count()
    page_count = ceil(total / page_size)
    offset = 0
    for _ in range(page_count):
        if is_the_task_interrupted():
            return
        handler(query_set[offset: offset + page_size])
        offset += page_size
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Generated by Django 4.2.15 on 2024-11-22 14:44

import dataset.models.data_set
from django.db import migrations, models


class Migration(migrations.Migration):
    """Add per-task status bookkeeping fields to Document and Paragraph."""

    dependencies = [
        ('dataset', '0010_file_meta'),
    ]

    operations = [
        # New JSON field holding aggregated per-status statistics for a document.
        migrations.AddField(
            model_name='document',
            name='status_meta',
            field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态统计数据'),
        ),
        # Same status-meta JSON field at paragraph level.
        migrations.AddField(
            model_name='paragraph',
            name='status_meta',
            field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态数据'),
        ),
        # Status becomes a packed character string (one slot per task type).
        # NOTE(review): `default=Status.__str__` passes an unbound method as the
        # default callable; Django will invoke it with no arguments -- confirm
        # Status.__str__ is defined so that this produces the intended default
        # string (e.g. via a metaclass) in dataset.models.data_set.
        migrations.AlterField(
            model_name='document',
            name='status',
            field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
        ),
        migrations.AlterField(
            model_name='paragraph',
            name='status',
            field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'),
        ),
    ]
Loading