From 11f549600db9ba4be72947d8a658cd8c40d34f79 Mon Sep 17 00:00:00 2001 From: gogoswift <48036113@qq.com> Date: Thu, 31 Oct 2024 11:32:44 +0800 Subject: [PATCH 1/5] Update operate.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修正输出格式与合并函数 --- lightrag/operate.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 8a6820f5..ef0d3398 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -15,6 +15,7 @@ pack_user_ass_to_openai_messages, split_string_by_multi_markers, truncate_list_by_token_size, + process_combine_contexts, ) from .base import ( BaseGraphStorage, @@ -1003,35 +1004,28 @@ def extract_sections(context): ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context) # Combine and deduplicate the entities - combined_entities_set = set( - filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n")) - ) - combined_entities = "\n".join(combined_entities_set) - + combined_entities = process_combine_contexts(hl_entities, ll_entities) + # Combine and deduplicate the relationships - combined_relationships_set = set( - filter( - None, - hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"), - ) - ) - combined_relationships = "\n".join(combined_relationships_set) + combined_relationships = process_combine_contexts(hl_relationships, ll_relationships) # Combine and deduplicate the sources - combined_sources_set = set( - filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n")) - ) - combined_sources = "\n".join(combined_sources_set) + combined_sources = process_combine_contexts(hl_sources, ll_sources) # Format the combined context return f""" -----Entities----- ```csv {combined_entities} +``` -----Relationships----- +```csv {combined_relationships} +``` -----Sources----- +```csv {combined_sources} +``` """ From e43446978b8ddd8958d45a6f1c7c68d28e242e5a
Mon Sep 17 00:00:00 2001 From: gogoswift <48036113@qq.com> Date: Thu, 31 Oct 2024 11:34:01 +0800 Subject: [PATCH 2/5] Update utils.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 混合检索的合并函数 --- lightrag/utils.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/lightrag/utils.py b/lightrag/utils.py index 0da4a51a..3daefb88 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -244,3 +244,40 @@ def xml_to_json(xml_file): except Exception as e: print(f"An error occurred: {e}") return None + +#混合检索中的合并函数 +def process_combine_contexts(hl, ll): + header = None + list_hl = hl.strip().split("\n") + list_ll = ll.strip().split("\n") + # 去掉第一个元素(如果不为空) + if list_hl: + header=list_hl[0] + list_hl = list_hl[1:] + if list_ll: + header = list_ll[0] + list_ll = list_ll[1:] + if header is None: + return "" + + # 去掉每个子元素中逗号分隔后的第一个元素(如果不为空) + if list_hl: + list_hl = [','.join(item.split(',')[1:]) for item in list_hl if item] + if list_ll: + list_ll = [','.join(item.split(',')[1:]) for item in list_ll if item] + + # 合并并去重 + combined_sources_set = set( + filter(None, list_hl + list_ll) + ) + + # 创建包含头部的新列表 + combined_sources = [header] + # 为 combined_sources_set 中的每个元素添加自增数字 + for i, item in enumerate(combined_sources_set, start=1): + combined_sources.append(f"{i},\t{item}") + + # 将列表转换为字符串,元素之间用换行符分隔 + combined_sources = "\n".join(combined_sources) + + return combined_sources From 00d570e509c019f9030a344d6335cb568f0c19de Mon Sep 17 00:00:00 2001 From: gogoswift <48036113@qq.com> Date: Thu, 31 Oct 2024 11:52:06 +0800 Subject: [PATCH 3/5] Update operate.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 消除text_units_section_list源文件的换行,统一混合检索上下文合并时的转换 --- lightrag/operate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index ef0d3398..3d7a752e 100644 --- a/lightrag/operate.py +++ 
b/lightrag/operate.py @@ -523,7 +523,7 @@ async def _build_local_query_context( text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"]]) + text_units_section_list.append([i, t["content"].replace("\n", "").replace("\r", "")]) text_units_context = list_of_list_to_csv(text_units_section_list) return f""" -----Entities----- @@ -788,7 +788,7 @@ async def _build_global_query_context( text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"]]) + text_units_section_list.append([i, t["content"].replace("\n", "").replace("\r", "")]) text_units_context = list_of_list_to_csv(text_units_section_list) return f""" From 7d884b97832b275dab3e87ae1692fbe06d95f9d0 Mon Sep 17 00:00:00 2001 From: gogoswift <48036113@qq.com> Date: Thu, 31 Oct 2024 14:31:26 +0800 Subject: [PATCH 4/5] Update utils.py --- lightrag/utils.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/lightrag/utils.py b/lightrag/utils.py index 3daefb88..7b17cbb6 100644 --- a/lightrag/utils.py +++ b/lightrag/utils.py @@ -1,5 +1,7 @@ import asyncio import html +import io +import csv import json import logging import os @@ -7,7 +9,7 @@ from dataclasses import dataclass from functools import wraps from hashlib import md5 -from typing import Any, Union +from typing import Any, Union,List import xml.etree.ElementTree as ET import numpy as np @@ -175,10 +177,21 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: return list_data -def list_of_list_to_csv(data: list[list]): - return "\n".join( - [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data] - ) +# def list_of_list_to_csv(data: list[list]): +# return "\n".join( +# [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data] +# ) +def list_of_list_to_csv(data: List[List[str]]) -> str: + output = 
io.StringIO() + writer = csv.writer(output) + writer.writerows(data) + return output.getvalue() +def csv_string_to_list(csv_string: str) -> List[List[str]]: + output = io.StringIO(csv_string) + reader = csv.reader(output) + return [row for row in reader] + + def save_data_to_file(data, file_name): @@ -248,8 +261,8 @@ def xml_to_json(xml_file): #混合检索中的合并函数 def process_combine_contexts(hl, ll): header = None - list_hl = hl.strip().split("\n") - list_ll = ll.strip().split("\n") + list_hl = csv_string_to_list(hl.strip()) + list_ll = csv_string_to_list(ll.strip()) # 去掉第一个元素(如果不为空) if list_hl: header=list_hl[0] @@ -259,12 +272,11 @@ def process_combine_contexts(hl, ll): list_ll = list_ll[1:] if header is None: return "" - - # 去掉每个子元素中逗号分隔后的第一个元素(如果不为空) + # 去掉每个子元素中的第一个元素(如果不为空),再转为一维数组,用于合并去重 if list_hl: - list_hl = [','.join(item.split(',')[1:]) for item in list_hl if item] + list_hl = [','.join(item[1:]) for item in list_hl if item] if list_ll: - list_ll = [','.join(item.split(',')[1:]) for item in list_ll if item] + list_ll = [','.join(item[1:]) for item in list_ll if item] # 合并并去重 combined_sources_set = set( @@ -272,12 +284,12 @@ def process_combine_contexts(hl, ll): ) # 创建包含头部的新列表 - combined_sources = [header] + combined_sources = [",\t".join(header)] # 为 combined_sources_set 中的每个元素添加自增数字 for i, item in enumerate(combined_sources_set, start=1): combined_sources.append(f"{i},\t{item}") - - # 将列表转换为字符串,元素之间用换行符分隔 + + # 将列表转换为字符串,子元素之间用换行符分隔 combined_sources = "\n".join(combined_sources) return combined_sources From 11b94032510062437c2a42a70c20703c5c3023fd Mon Sep 17 00:00:00 2001 From: gogoswift <48036113@qq.com> Date: Thu, 31 Oct 2024 15:17:02 +0800 Subject: [PATCH 5/5] Update operate.py --- lightrag/operate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightrag/operate.py b/lightrag/operate.py index 3d7a752e..ef0d3398 100644 --- a/lightrag/operate.py +++ b/lightrag/operate.py @@ -523,7 +523,7 @@ async def _build_local_query_context( 
text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"].replace("\n", "").replace("\r", "")]) + text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) return f""" -----Entities----- @@ -788,7 +788,7 @@ async def _build_global_query_context( text_units_section_list = [["id", "content"]] for i, t in enumerate(use_text_units): - text_units_section_list.append([i, t["content"].replace("\n", "").replace("\r", "")]) + text_units_section_list.append([i, t["content"]]) text_units_context = list_of_list_to_csv(text_units_section_list) return f"""