-
Notifications
You must be signed in to change notification settings - Fork 2.2k
feat: 文档状态70% #1674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
feat: 文档状态70% #1674
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,24 +9,29 @@ | |
import datetime | ||
import logging | ||
import os | ||
import threading | ||
import traceback | ||
from typing import List | ||
|
||
import django.db.models | ||
from django.db import models | ||
from django.db.models import QuerySet | ||
from django.db.models.functions import Substr, Reverse | ||
from langchain_core.embeddings import Embeddings | ||
|
||
from common.config.embedding_config import VectorStore | ||
from common.db.search import native_search, get_dynamics_model | ||
from common.event.common import embedding_poxy | ||
from common.db.search import native_search, get_dynamics_model, native_update | ||
from common.db.sql_execute import sql_execute, update_execute | ||
from common.util.file_util import get_file_content | ||
from common.util.lock import try_lock, un_lock | ||
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping | ||
from common.util.page_utils import page | ||
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping, TaskType, State | ||
from embedding.models import SourceType, SearchMode | ||
from smartdoc.conf import PROJECT_DIR | ||
|
||
max_kb_error = logging.getLogger(__file__) | ||
max_kb = logging.getLogger(__file__) | ||
lock = threading.Lock() | ||
|
||
|
||
class SyncWebDatasetArgs: | ||
|
@@ -114,7 +119,8 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): | |
@param embedding_model: 向量模型 | ||
""" | ||
max_kb.info(f"开始--->向量化段落:{paragraph_id}") | ||
status = Status.success | ||
# 更新到开始状态 | ||
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, State.STARTED) | ||
try: | ||
data_list = native_search( | ||
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( | ||
|
@@ -125,23 +131,114 @@ def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): | |
# 删除段落 | ||
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) | ||
|
||
def is_save_function(): | ||
return QuerySet(Paragraph).filter(id=paragraph_id).exists() | ||
def is_the_task_interrupted(): | ||
_paragraph = QuerySet(Paragraph).filter(id=paragraph_id).first() | ||
if _paragraph is None or Status(_paragraph.status)[TaskType.EMBEDDING] == State.REVOKE: | ||
return True | ||
return False | ||
|
||
# 批量向量化 | ||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) | ||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_the_task_interrupted) | ||
# 更新到开始状态 | ||
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, | ||
State.SUCCESS) | ||
except Exception as e: | ||
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') | ||
status = Status.error | ||
ListenerManagement.update_status(QuerySet(Paragraph).filter(id=paragraph_id), TaskType.EMBEDDING, | ||
State.FAILURE) | ||
finally: | ||
QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status}) | ||
max_kb.info(f'结束--->向量化段落:{paragraph_id}') | ||
|
||
@staticmethod | ||
def embedding_by_data_list(data_list: List, embedding_model: Embeddings): | ||
# 批量向量化 | ||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True) | ||
|
||
@staticmethod | ||
def get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, post_apply=lambda: None): | ||
def embedding_paragraph_apply(paragraph_list): | ||
for paragraph in paragraph_list: | ||
if is_the_task_interrupted(): | ||
break | ||
ListenerManagement.embedding_by_paragraph(str(paragraph.get('id')), embedding_model) | ||
post_apply() | ||
|
||
return embedding_paragraph_apply | ||
|
||
@staticmethod | ||
def get_aggregation_document_status(document_id): | ||
def aggregation_document_status(): | ||
sql = get_file_content( | ||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) | ||
native_update({'document_custom_sql': QuerySet(Document).filter(id=document_id), | ||
'default_sql': QuerySet(Document).filter(id=document_id)}, sql, with_table_name=True) | ||
|
||
return aggregation_document_status | ||
|
||
@staticmethod | ||
def get_aggregation_document_status_by_dataset_id(dataset_id): | ||
def aggregation_document_status(): | ||
sql = get_file_content( | ||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) | ||
native_update({'document_custom_sql': QuerySet(Document).filter(dataset_id=dataset_id), | ||
'default_sql': QuerySet(Document).filter(dataset_id=dataset_id)}, sql) | ||
|
||
return aggregation_document_status | ||
|
||
@staticmethod | ||
def get_aggregation_document_status_by_query_set(queryset): | ||
def aggregation_document_status(): | ||
sql = get_file_content( | ||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_status_meta.sql')) | ||
native_update({'document_custom_sql': queryset, 'default_sql': queryset}, sql) | ||
|
||
return aggregation_document_status | ||
|
||
@staticmethod | ||
def post_update_document_status(document_id, task_type: TaskType): | ||
_document = QuerySet(Document).filter(id=document_id).first() | ||
|
||
status = Status(_document.status) | ||
if status[task_type] == State.REVOKE: | ||
status[task_type] = State.REVOKED | ||
else: | ||
status[task_type] = State.SUCCESS | ||
for item in _document.status_meta.get('aggs', []): | ||
agg_status = item.get('status') | ||
agg_count = item.get('count') | ||
if Status(agg_status)[task_type] == State.FAILURE and agg_count > 0: | ||
status[task_type] = State.FAILURE | ||
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), task_type, status[task_type]) | ||
|
||
ListenerManagement.update_status(QuerySet(Paragraph).annotate( | ||
reversed_status=Reverse('status'), | ||
task_type_status=Substr('reversed_status', task_type.value, | ||
task_type.value), | ||
).filter(task_type_status=State.REVOKE.value).filter(document_id=document_id).values('id'), | ||
task_type, | ||
State.REVOKED) | ||
|
||
@staticmethod | ||
def update_status(query_set: QuerySet, taskType: TaskType, state: State): | ||
exec_sql = get_file_content( | ||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_paragraph_status.sql')) | ||
bit_number = len(TaskType) | ||
up_index = taskType.value - 1 | ||
next_index = taskType.value + 1 | ||
current_index = taskType.value | ||
status_number = state.value | ||
params_dict = {'${bit_number}': bit_number, '${up_index}': up_index, | ||
'${status_number}': status_number, '${next_index}': next_index, | ||
'${table_name}': query_set.model._meta.db_table, '${current_index}': current_index} | ||
for key in params_dict: | ||
_value_ = params_dict[key] | ||
exec_sql = exec_sql.replace(key, str(_value_)) | ||
lock.acquire() | ||
try: | ||
native_update(query_set, exec_sql) | ||
finally: | ||
lock.release() | ||
|
||
@staticmethod | ||
def embedding_by_document(document_id, embedding_model: Embeddings): | ||
""" | ||
|
@@ -153,33 +250,29 @@ def embedding_by_document(document_id, embedding_model: Embeddings): | |
if not try_lock('embedding' + str(document_id)): | ||
return | ||
max_kb.info(f"开始--->向量化文档:{document_id}") | ||
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) | ||
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding}) | ||
status = Status.success | ||
# 批量修改状态为PADDING | ||
ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.EMBEDDING, State.STARTED) | ||
try: | ||
data_list = native_search( | ||
{'problem': QuerySet( | ||
get_dynamics_model({'paragraph.document_id': django.db.models.CharField()})).filter( | ||
**{'paragraph.document_id': document_id}), | ||
'paragraph': QuerySet(Paragraph).filter(document_id=document_id)}, | ||
select_string=get_file_content( | ||
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) | ||
# 删除文档向量数据 | ||
VectorStore.get_embedding_vector().delete_by_document_id(document_id) | ||
|
||
def is_save_function(): | ||
return QuerySet(Document).filter(id=document_id).exists() | ||
|
||
# 批量向量化 | ||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) | ||
def is_the_task_interrupted(): | ||
document = QuerySet(Document).filter(id=document_id).first() | ||
if document is None or Status(document.status)[TaskType.EMBEDDING] == State.REVOKE: | ||
return True | ||
return False | ||
|
||
# 根据段落进行向量化处理 | ||
page(QuerySet(Paragraph).filter(document_id=document_id).values('id'), 5, | ||
ListenerManagement.get_embedding_paragraph_apply(embedding_model, is_the_task_interrupted, | ||
ListenerManagement.get_aggregation_document_status( | ||
document_id)), | ||
is_the_task_interrupted) | ||
except Exception as e: | ||
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') | ||
status = Status.error | ||
finally: | ||
# 修改状态 | ||
QuerySet(Document).filter(id=document_id).update( | ||
**{'status': status, 'update_time': datetime.datetime.now()}) | ||
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) | ||
ListenerManagement.post_update_document_status(document_id, TaskType.EMBEDDING) | ||
ListenerManagement.get_aggregation_document_status(document_id)() | ||
max_kb.info(f"结束--->向量化文档:{document_id}") | ||
un_lock('embedding' + str(document_id)) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这段代码是关于自然语言处理(NLP)中使用Django ORM和SQLite数据库查询的一个简单示例。 以下是几点需要改善的地方:
from typing import List
import django.db.models
from django.db import models
from django.db.models.functions import Concat, Coalesce, StrStr
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model, native_update, sql_execute
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping
from embedding.models import SourceType, SearchMode
from smartdoc.conf import PROJECT_DIR
class MaxKBError(Exception): pass
def embedding_by_paragraph(paragraph: dict):
"""向量化段落
Args:
paragraph (dict): 段落字典数据
Returns:
None : 运行结果存储在paragraph中
"""
@try_lock('embedding_' + str(paragraph['id']))
def run_embedding_by_paragraph(context=None, error_callback=un_lock):
我推荐你对整个代码做深入理解并根据实际项目情况进行调整和改进。希望这些建议有所帮助! |
||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# coding=utf-8 | ||
""" | ||
@project: MaxKB | ||
@Author:虎 | ||
@file: page_utils.py | ||
@date:2024/11/21 10:32 | ||
@desc: | ||
""" | ||
from math import ceil | ||
|
||
|
||
def page(query_set, page_size, handler, is_the_task_interrupted=lambda: False): | ||
""" | ||
|
||
@param query_set: 查询query_set | ||
@param page_size: 每次查询大小 | ||
@param handler: 数据处理器 | ||
@param is_the_task_interrupted: 任务是否被中断 | ||
@return: | ||
""" | ||
count = query_set.count() | ||
for i in range(0, ceil(count / page_size)): | ||
if is_the_task_interrupted(): | ||
return | ||
offset = i * page_size | ||
paragraph_list = query_set[offset: offset + page_size] | ||
handler(paragraph_list) |
34 changes: 34 additions & 0 deletions
34
apps/dataset/migrations/0011_document_status_meta_paragraph_status_meta_and_more.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Generated by Django 4.2.15 on 2024-11-22 14:44 | ||
|
||
import dataset.models.data_set | ||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('dataset', '0010_file_meta'), | ||
] | ||
|
||
operations = [ | ||
migrations.AddField( | ||
model_name='document', | ||
name='status_meta', | ||
field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态统计数据'), | ||
), | ||
migrations.AddField( | ||
model_name='paragraph', | ||
name='status_meta', | ||
field=models.JSONField(default=dataset.models.data_set.default_status_meta, verbose_name='状态数据'), | ||
), | ||
migrations.AlterField( | ||
model_name='document', | ||
name='status', | ||
field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'), | ||
), | ||
migrations.AlterField( | ||
model_name='paragraph', | ||
name='status', | ||
field=models.CharField(default=dataset.models.data_set.Status.__str__, max_length=20, verbose_name='状态'), | ||
), | ||
] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
上述代码存在以下问题:
QuerySet
应该是 django.db.models.Model 类,并不是 Django ORM 的 QuerySelector 这个东西。UpdateStatus
,BatchSaveDataList
等方法封装成专用的方法。以下是根据上述建议进行的一些改进:
这样就解决了前面提到的所有问题并提供了更好的模板和结构。在实际编写程序过程中,请按照具体实现环境及逻辑来调整这些细节。
最后,由于涉及较多的实际业务需求,上面提供的代码只是一个非常基础和简洁的设计示例,可能需要进一步的业务理解和技术开发知识才能正确地运行和应用这个示例。