Skip to content

Commit c1f3791

Browse files
committed
add code generation scripts
1 parent 74c12c2 commit c1f3791

File tree

9 files changed

+107
-23
lines changed

9 files changed

+107
-23
lines changed

.gitignore

+2-13
Original file line numberDiff line numberDiff line change
@@ -158,16 +158,5 @@ cython_debug/
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
#.idea/
161-
automatic_gen/data/
162-
query_sft/models/
163-
query_sft/data/
164-
query_dpo/models
165-
dart-math/
166-
evaluation/code/
167-
evaluation/outputs/
168-
evol_instruct/output
169-
temp.py
170-
automatic_gen/*data
171-
baselines
172-
check_solvability_difficulty
173-
test_rm
161+
src/data_generation/generation
162+
src/train_question_generator/models

README.md

+10-8
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,26 @@
99
<a href="https://opennlg.cn/"><img src="https://img.shields.io/badge/Organization-OpenNLG%20Group-blueviolet"></a>
1010
</p>
1111

12-
We introduce ScaleQuest, a scalable, cost-effective, and novel data synthesis method that utilizes small-size open-source models to generate questions from scratch without the need for seed data with complex augmentation constraints.
12+
We introduce ScaleQuest, a scalable, cost-effective, and novel data synthesis method that utilizes small-size open-source models to generate questions from scratch without the need for seed questions with complex augmentation constraints.
1313

1414
![](img/results.png)
1515

1616
This repository contains our complete data synthesis method, including:
1717

18-
1. Training a question generator through question fine-tuning (code in the `qft_train` folder).
19-
2. Constructing preference data (code in the `question_optim` folder) and performing question preference optimization (code in the `qpo_train` folder).
20-
3. Using the trained question generator to synthesize questions (code in the `data_generation` folder).
21-
4. Applying a filtering process to the generated questions (code in the `question_filtering` folder).
22-
5. Generating responses (code in the `data_generation` folder) and applying a reward filtering strategy (code in the `reward_filtering` folder).
23-
6. For instruction-tuning and evaluation, we directly use the DART-Math framework.
18+
See the Method Overview section below for the individual steps.
19+
2420
2521
We randomly sampled 100 generated data points and placed them in `data_samples/samples.jsonl`.
2622
2723
## Method Overview
2824
2925
![](img/method.png)
3026
31-
27+
1. Training a question generator
28+
- Performing question fine-tuning (code in the `qft_train` folder).
29+
- Constructing preference data (code in the `question_optim` folder) and performing question preference optimization (code in the `qpo_train` folder).
30+
2. Using the trained question generator to synthesize questions (code in the `data_generation` folder).
31+
- Applying a filtering process to the generated questions (code in the `question_filtering` folder).
32+
- Generating responses (code in the `data_generation` folder) and applying a reward filtering strategy (code in the `reward_filtering` folder).
33+
3. For instruction-tuning and evaluation, we directly use the DART-Math framework.
3234
File renamed without changes.

src/data_generation/gen.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ def get_args():
8181
elif "deepseek" in args.qry_prompt_type:
8282
pre_query_template = "<|begin▁of▁sentence|>User: "
8383
stop_tokens = ["<|begin▁of▁sentence|>", "<|end▁of▁sentence|>"]
84+
elif "qwen2.5-code" in args.qry_prompt_type:
85+
pre_query_template = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n"
86+
stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"]
8487
else:
8588
raise NotImplementedError(
8689
f"Query prompt type {args.qry_prompt_type} is not implemented"
@@ -155,6 +158,13 @@ def flatten_batch_and_strip(line_data):
155158
"and put your final answer within \\boxed{{}}.\n\nAssistant:"
156159
)
157160
stop_tokens = ["<|begin▁of▁sentence|>", "<|end▁of▁sentence|>"]
161+
elif "qwen2.5-code" in args.res_prompt_type:
162+
res_generation_template = (
163+
"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n"
164+
"<|im_start|>user\n{input}<|im_end|>\n"
165+
"<|im_start|>assistant\n"
166+
)
167+
stop_tokens = ["<|im_start|>", "<|im_end|>", "<|endoftext|>"]
158168
else:
159169
raise NotImplementedError(
160170
f"Response prompt type {args.res_prompt_type} is not implemented"
@@ -203,7 +213,9 @@ def filter_data(line_data):
203213
has_answer = "boxed" in response or "he answer is" in response or "final answer is" in response
204214
return has_answer
205215

206-
dataset = dataset.map(strip_data, concurrency=4).filter(filter_data, concurrency=4)
216+
dataset = dataset.map(strip_data, concurrency=4)
217+
if "math" in args.res_prompt_type:
218+
dataset = dataset.filter(filter_data, concurrency=4)
207219

208220
res_gen_output_path = os.path.join(
209221
args.output_folder,

src/train_question_generator/qft_train/train.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ def formatting_prompts_func(example):
128128
text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['query'].strip()}<|im_end|>"
129129
elif script_args.prompt_type == "deepseek-math":
130130
text = f"User: {example['query'].strip()}\n\n<|end▁of▁sentence|>"
131+
elif script_args.prompt_type == "deepseek-code":
132+
text = f"You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n" \
133+
f"### Instruction:\n{example['query'].strip()}\n### Response:\n"
134+
elif script_args.prompt_type == "qwen2.5-code":
135+
text = f"<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n{example['query'].strip()}<|im_end|>"
131136
else:
132137
raise NotImplementedError(
133138
f"Prompt type {script_args.prompt_type} not implemented."
@@ -140,7 +145,7 @@ def formatting_prompts_func(example):
140145
train_dataset = dataset["train"]
141146
eval_dataset = dataset["test"] if "test" in dataset else None
142147
if script_args.max_training_samples > 0:
143-
train_dataset = train_dataset.select(range(script_args.max_training_samples))
148+
train_dataset = train_dataset.shuffle(seed=42).select(range(script_args.max_training_samples))
144149

145150

146151
# formatting_prompts_func
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Step 1: QFT
2+
ACCELERATE_LOG_LEVEL=info accelerate launch \
3+
--config_file ./zero3.yaml \
4+
--main_process_port 29600 \
5+
qft_train/train.py \
6+
--model_path deepseek-ai/deepseek-coder-6.7b-instruct \
7+
--dataset_path /path/to/CodeFeedback-Filtered-Instruction \
8+
--prompt_type deepseek-code \
9+
--num_train_epochs 1 \
10+
--gradient_checkpointing false \
11+
--max_length 256 \
12+
--output_dir models/Deepseek-Coder-7B-QFT \
13+
--per_device_train_batch_size 1 \
14+
--per_device_eval_batch_size 1 \
15+
--gradient_accumulation_steps 4 \
16+
17+
# Step 2: QPO
18+
ACCELERATE_LOG_LEVEL=info accelerate launch \
19+
--config_file ./zero3.yaml \
20+
--main_process_port 29601 \
21+
train.py \
22+
--model_path models/Deepseek-Coder-7B-QFT \
23+
--ref_model models/Deepseek-Coder-7B-QFT \
24+
--dataset_path /path/to/qpo_data \
25+
--prompt_type deepseek-code \
26+
--run_name deepseek-code-qgen-sft-dpo \
27+
--learning_rate 5e-7 \
28+
--lr_scheduler_type cosine \
29+
--loss_type sigmoid \
30+
--warmup_steps 20 \
31+
--num_train_epochs 1 \
32+
--gradient_checkpointing true \
33+
--max_length 1024 \
34+
--output_dir models/Deepseek-Coder-7B-QGen \
35+
--per_device_train_batch_size 8 \
36+
--per_device_eval_batch_size 8 \
37+
--gradient_accumulation_steps 2 \
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Step 1: QFT
2+
ACCELERATE_LOG_LEVEL=info accelerate launch \
3+
--config_file scripts/zero3.yaml \
4+
--main_process_port 29500 \
5+
qft_train/train.py \
6+
--model_path Qwen/Qwen2.5-Coder-7B-Instruct \
7+
--dataset_path /path/to/CodeFeedback-Filtered-Instruction \
8+
--prompt_type qwen2.5-code \
9+
--num_train_epochs 1 \
10+
--gradient_checkpointing false \
11+
--max_length 256 \
12+
--max_training_samples 20000 \
13+
--output_dir models/Qwen2.5-Coder-7B-QFT \
14+
--per_device_train_batch_size 1 \
15+
--per_device_eval_batch_size 1 \
16+
--gradient_accumulation_steps 4 \
17+
18+
# Step 2: QPO
19+
ACCELERATE_LOG_LEVEL=info accelerate launch \
20+
--config_file ./zero3.yaml \
21+
--main_process_port 29051 \
22+
train.py \
23+
--model_path models/Qwen2.5-Coder-7B-QFT \
24+
--ref_model models/Qwen2.5-Coder-7B-QFT \
25+
--dataset_path /path/to/qpo_data \
26+
--prompt_type qwen2.5-code \
27+
--run_name qwen2-code-qgen-sft-dpo \
28+
--learning_rate 5e-7 \
29+
--lr_scheduler_type cosine \
30+
--loss_type sigmoid \
31+
--warmup_steps 20 \
32+
--num_train_epochs 1 \
33+
--gradient_checkpointing true \
34+
--max_length 1024 \
35+
--output_dir models/Qwen2-Coder-7B-QGen \
36+
--per_device_train_batch_size 8 \
37+
--per_device_eval_batch_size 8 \
38+
--gradient_accumulation_steps 2 \

src/train_question_generator/scripts/run_qwen2_qft.sh → src/train_question_generator/scripts/run_qwen2math_qft.sh

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ train.py \
1414
--per_device_eval_batch_size 1 \
1515
--gradient_accumulation_steps 4 \
1616

17+
# Step 2: QPO
1718
ACCELERATE_LOG_LEVEL=info accelerate launch \
1819
--config_file ./zero3.yaml \
1920
--main_process_port 29051 \

0 commit comments

Comments
 (0)