Clean qa #81

Merged (2 commits) on Mar 16, 2024

3 changes: 3 additions & 0 deletions .idea/.gitignore
12 changes: 12 additions & 0 deletions .idea/EmoLLM.iml
11 changes: 11 additions & 0 deletions .idea/aws.xml
14 changes: 14 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml
6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml
4 changes: 4 additions & 0 deletions .idea/misc.xml
8 changes: 8 additions & 0 deletions .idea/modules.xml
6 changes: 6 additions & 0 deletions .idea/vcs.xml

11 changes: 11 additions & 0 deletions scripts/qa_generation/Clean_QA.md
@@ -0,0 +1,11 @@
# Cleaning QA pairs
Call qwen to judge whether each QA pair belongs to the domain of psychology, and remove the QA pairs that do not.

## Step 1
1. Prepare the QA pair data that needs cleaning
2. Place the data in the data folder at the same level as model
3. Update judge_dir in config/config.py to match your folder name. I did not rename my folder, so my setting is judge_dir = os.path.join(data_dir, '数据整合')

## Step 2
1. Run QA_clean.py
2. The cleaned QA pairs are saved in jsonl format under data/cleaned (a sketch of the expected input layout follows below)
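
As an illustration, here is a minimal sketch of the expected input, assuming each line of the source jsonl files is one QA-pair JSON object (the question/answer field names are hypothetical; the pipeline only requires one JSON object per line):

```python
import json
import os

# Hypothetical sample file under the folder that judge_dir points to.
os.makedirs('./data/数据整合', exist_ok=True)
sample = {"question": "什么是焦虑症?", "answer": "焦虑症是一种以过度担忧为主要特征的心理障碍。"}
with open('./data/数据整合/sample.jsonl', 'w', encoding='utf-8') as f:
    f.write(json.dumps(sample, ensure_ascii=False) + '\n')

# After running QA_clean.py, lines the model judges in-domain ('1')
# are written unchanged to ./data/cleaned/sample.jsonl.
```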
111 changes: 111 additions & 0 deletions scripts/qa_generation/QA_clean.py
@@ -0,0 +1,111 @@
import os
import json
import time
from tqdm import tqdm
import concurrent.futures
from datetime import datetime
import numpy as np

from config.config import result_dir, clean_dir, storage_interval, window_size, overlap_size, multi_process_num
from model.qwen import call_qwen_single_turn, call_qwen_Psychology_QA_Pairs
from util.logger import get_logger
from util.data_loader import get_jsonl_file_paths, get_file_list, get_QA_pairs, get_txt_content, capture_qa, merge_sub_qa_generation, save_to_file

logger = get_logger()


def single_thread_generate(thread_num, interval, model_caller, storage_jsonl_path, contents):

    storage_counter = 0
    judge_list = []
    for content in tqdm(contents):
        # print('content: ', content)
        try:
            # model_caller sends the input content to the judging model and returns its response
            response = model_caller(content)
            # print('response: ', response)

            if response == '1':
                content = json.loads(content)
                judge_list.append(content)
                storage_counter += 1
            else:
                continue

            # after every `interval` kept pairs, flush judge_list to storage_jsonl_path
            if storage_counter % interval == 0:
                save_to_file(storage_jsonl_path, judge_list)
                storage_counter = 0
                judge_list = []

        except Exception as exc:
            logger.error("QA cleaning error : %s" % (exc))

    # finally, save any remaining items to the file
    if judge_list:
        save_to_file(storage_jsonl_path, judge_list)
        judge_list = []


"""
Clean QA pairs
model_name: name of the callable model; only qwen is implemented for now
interval: storage interval, i.e. how many kept pairs between file writes; an overly frequent interval increases IO overhead
"""
def clean_qa(
    model_name: str = 'qwen',
    interval: int = 10,
):
    # current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if model_name == 'qwen':
        model_caller = call_qwen_Psychology_QA_Pairs
    else:
        logger.warning('This model is currently not supported and will call the default model - qwen.')
        model_caller = call_qwen_Psychology_QA_Pairs
        model_name = 'qwen'

    logger.info(f'The called model is: {model_name}.')
    logger.info(f'The storage interval is: {interval}.')

    file_lists = get_jsonl_file_paths()  # paths of all .jsonl files under the 数据整合 folder

    for file_path in file_lists:
        # all QA pairs from one jsonl file
        contents = get_QA_pairs(file_path)
        # print(contents)

        file_name = os.path.basename(file_path)
        print(file_name)
        storage_jsonl_path = os.path.join(clean_dir, file_name)

        logger.info(f'The cleaned QA pairs will be stored in {storage_jsonl_path}.')

        contents_array = np.array(contents)
        chunks = np.array_split(contents_array, multi_process_num)

        # build the list of per-thread argument sets
        parameters_list = list()
        for thread_num, chunk in enumerate(chunks):
            parameters_list.append(
                [thread_num, interval, model_caller, storage_jsonl_path, list(chunk)]
            )

        with concurrent.futures.ThreadPoolExecutor(max_workers=multi_process_num) as executor:
            # submit single_thread_generate once per argument set
            futures = [executor.submit(single_thread_generate, *parameters) for parameters in parameters_list]

            for future in concurrent.futures.as_completed(futures):
                try:
                    future.result()
                except Exception as exc:
                    logger.error("Thread generated an exception: %s" % (exc))

        merge_sub_qa_generation(result_dir, storage_jsonl_path)


if __name__ == '__main__':
    # create the data/cleaned output folder
    os.makedirs('./data/cleaned', exist_ok=True)
    clean_qa(interval=storage_interval)
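
For a quick sanity check of the flush logic in single_thread_generate, here is a minimal sketch that swaps a stub in for the qwen caller (the stub and sample data are hypothetical; interval=1 flushes after every kept pair):

```python
import json

# Hypothetical stand-in for call_qwen_Psychology_QA_Pairs:
# judge a QA pair as in-domain iff it mentions 心理.
def fake_caller(content: str) -> str:
    return '1' if '心理' in content else '0'

samples = [
    json.dumps({"question": "什么是心理咨询?", "answer": "..."}, ensure_ascii=False),
    json.dumps({"question": "如何修理自行车?", "answer": "..."}, ensure_ascii=False),
]

# Only the first sample should survive into demo_cleaned.jsonl.
single_thread_generate(0, 1, fake_caller, './demo_cleaned.jsonl', samples)
```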
8 changes: 8 additions & 0 deletions scripts/qa_generation/choose_prompt.md
@@ -0,0 +1,8 @@
You are an experienced psychological counselor who is familiar with psychology. Based on the QA pair I provide, judge whether this QA pair belongs to the domain of psychology.

The criteria are as follows:
- If the QA pair belongs to the domain of psychology, return 1
- If the QA pair does not belong to the domain of psychology, return 0


Below is the content of the given psychology QA pair:
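
A hypothetical round trip through this prompt, assuming DASHSCOPE_API_KEY is configured and the model honors the 0/1 protocol above (the QA pair is illustrative):

```python
import json
from model.qwen import call_qwen_Psychology_QA_Pairs

qa = json.dumps({"question": "长期失眠会引发焦虑吗?", "answer": "..."}, ensure_ascii=False)
label = call_qwen_Psychology_QA_Pairs(qa)
print(label)  # expected: '1' (psychology) or '0' (not psychology)
```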
8 changes: 6 additions & 2 deletions scripts/qa_generation/config/config.py
@@ -10,15 +10,19 @@
model_dir = os.path.join(base_dir, 'model') # model

# data
data_dir = os.path.join(base_dir, 'data') # data
data_dir = os.path.join(base_dir, 'data')
clean_dir = os.path.join(data_dir, 'cleaned')
judge_dir = os.path.join(data_dir, '数据整合')
result_dir = os.path.join(data_dir, 'generated') # result

# log
log_dir = os.path.join(base_dir, 'log') # log
log_file_path = os.path.join(log_dir, 'log.log') # file

# system prompt
# prompt content
system_prompt_file_path = os.path.join(base_dir, 'system_prompt_v2.md') # system prompt
wash_prompt_file_path = os.path.join(base_dir, 'choose_prompt.md')


"""
@@ -28,11 +32,11 @@
DASHSCOPE_API_KEY = ''



"""
Control parameters
"""
storage_interval = 10
window_size = 8
overlap_size = 2
multi_process_num = 3
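
As a rough illustration of how multi_process_num drives the chunking in QA_clean.py (a sketch with arbitrary sample data):

```python
import numpy as np

multi_process_num = 3  # as configured above
contents = [f'qa_{i}' for i in range(7)]  # hypothetical QA-pair lines

# QA_clean.py splits the work the same way before starting the thread pool:
chunks = np.array_split(np.array(contents), multi_process_num)
print([len(c) for c in chunks])  # [3, 2, 2]: near-even shares, one per thread
```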

30 changes: 29 additions & 1 deletion scripts/qa_generation/model/qwen.py
@@ -5,7 +5,7 @@

from config.config import DASHSCOPE_API_KEY
from util.logger import get_logger
from util.prompt_loader import load_system_prompt
from util.prompt_loader import load_system_prompt, load_wash_prompt


dashscope.api_key = DASHSCOPE_API_KEY
@@ -39,3 +39,31 @@ def call_qwen_single_turn(query: str) -> str:
            response.code, response.message
        ))
        return ""


def call_qwen_Psychology_QA_Pairs(query: str) -> str:
    messages = [
        {
            'role': Role.SYSTEM,
            'content': load_wash_prompt()
        },
        {
            'role': Role.USER,
            'content': query
        }
    ]
    response = Generation.call(
        model='qwen-max-1201',
        messages=messages,
        result_format='message',
        stream=False,
        incremental_output=False
    )
    if response.status_code == HTTPStatus.OK:
        return response.output.choices[0]['message']['content']
    else:
        logger.error('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        return ""
39 changes: 34 additions & 5 deletions scripts/qa_generation/util/data_loader.py
@@ -4,11 +4,39 @@
import glob
from typing import List, Dict

from config.config import data_dir
from config.config import data_dir, judge_dir
from util.logger import get_logger

logger = get_logger()


"""
递归获取 数据整合 下的所有 .jsonl 文件列表
"""
def get_jsonl_file_paths() -> List[str]:
json_file_paths = []

# 遍历根目录及其所有子目录
for dirpath, dirnames, filenames in os.walk(judge_dir):
# 对每个文件进行检查
for filename in filenames:
# 使用正则表达式匹配以.jsonl结尾的文件名
if re.search(r'\.jsonl$', filename):
# 构建完整的文件路径并添加到列表中
json_file_path = os.path.join(dirpath, filename)
json_file_paths.append(json_file_path)

return json_file_paths

def get_QA_pairs(json_path):
with open(json_path, 'r', encoding='utf-8') as f:
content = f.read().strip()

# 按照换行符分割字符串
QA_Pairs = content.split('\n')

return QA_Pairs

"""
递归获取 data_dir 下的所有 .txt 文件列表
"""
@@ -47,7 +75,7 @@ def get_txt_content(
    res = []
    sentences_amount = len(sentences)
    start_index, end_index = 0, sentences_amount - window_size
    ## check length
    # check length
    if window_size < overlap_size:
        logger.error("window_size must be greater than or equal to overlap_size")
        return None
@@ -56,7 +84,7 @@
        return ['\n'.join(sentences)]

    for i in range(start_index, end_index + 1, overlap_size):
        res.append('\n'.join(sentences[i : i + window_size]))
        res.append('\n'.join(sentences[i: i + window_size]))
    return res


@@ -80,6 +108,7 @@ def capture_qa(content: str) -> List[Dict]:
        logger.warning("No JSON block found.")
        return None


"""
Write storage_list out to storage_jsonl_path
"""
@@ -88,6 +117,7 @@ def save_to_file(storage_jsonl_path, storage_list):
        for item in storage_list:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')


"""
Merge the files produced by the concurrent threads into a single file
"""
@@ -102,5 +132,4 @@ def merge_sub_qa_generation(directory, storage_jsonl_path):
                for line in f:
                    file_contents.append(json.loads(line))
            os.remove(file_path)
    save_to_file(storage_jsonl_path, file_contents)

    save_to_file(storage_jsonl_path, file_contents)
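
A tiny sketch of what get_QA_pairs returns, using a throwaway file (the contents are hypothetical):

```python
with open('demo.jsonl', 'w', encoding='utf-8') as f:
    f.write('{"question": "q1", "answer": "a1"}\n')
    f.write('{"question": "q2", "answer": "a2"}\n')

pairs = get_QA_pairs('demo.jsonl')
# ['{"question": "q1", "answer": "a1"}', '{"question": "q2", "answer": "a2"}']
# Each element is still a raw JSON string; QA_clean.py calls json.loads on it
# only after the model judges the pair in-domain.
```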
7 changes: 7 additions & 0 deletions scripts/qa_generation/util/prompt_loader.py
@@ -1,7 +1,14 @@
from config.config import system_prompt_file_path
from config.config import wash_prompt_file_path


def load_system_prompt() -> str:
    with open(system_prompt_file_path, 'r', encoding='utf-8') as f:
        system_prompt = f.read()
    return system_prompt


def load_wash_prompt() -> str:
    with open(wash_prompt_file_path, 'r', encoding='utf-8') as f:
        wash_prompt = f.read()
    return wash_prompt