@@ -18,31 +18,36 @@ def main():
1818 canonical_solution = f"```python\n { example ['canonical_solution' ]} \n ```"
1919 text = [{"role" : "user" , "message" : generate_prompt (example ["prompt" ], example ["test" ])}, {"role" : "assistant" , "message" : format_solution (canonical_solution , example ["prompt" ])}]
2020 texts .append (text )
21- print (text )
2221 ds [split ] = ds [split ].add_column (name = "text" , column = texts )
2322
24- # sample
25- all_samples = generate_predictions (
26- args .model_name_or_path , ds ["train" ], args .temperature , args .n
27- )
28- assert len (ds ["train" ]) == len (all_samples )
29-
30- # verify and construct the training set
31- all_traces , all_execution_results = execute_tests (ds ["train" ], all_samples )
32- passed_examples = []
33- for example , execution_results , samples in zip (
34- ds ["train" ], all_execution_results , all_samples
35- ):
36- for execution_result , sample in zip (execution_results , samples ):
37- # pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
38- if execution_result == 0 :
39- example ["text" ] = [{"role" : "user" , "message" : generate_prompt (example ["prompt" ], example ["test" ])}, {"role" : "assistant" , "message" : format_solution (sample , example ["prompt" ])}]
40- passed_examples .append (example )
41- break
42- raw_datasets = DatasetDict ({"train" : Dataset .from_list (passed_examples ), "validation" : ds ["validation" ]})
43-
44- # train
45- train (raw_datasets , args .model_name_or_path , args )
23+ model_name = args .model_name_or_path
24+ for i in range (args .iteration ):
25+ # sample
26+ all_samples = generate_predictions (
27+ model_name , ds ["train" ], args .temperature , args .n
28+ )
29+ ds ["train" ].add_column (name = "sample" , column = all_samples ).to_json (f"{ args .output_dir } /data/samples-iter{ i } .json" )
30+ assert len (ds ["train" ]) == len (all_samples )
31+
32+ # verify and construct the training set
33+ all_traces , all_execution_results = execute_tests (ds ["train" ], all_samples )
34+ passed_examples = []
35+ for example , execution_results , samples in zip (
36+ ds ["train" ], all_execution_results , all_samples
37+ ):
38+ for execution_result , sample in zip (execution_results , samples ):
39+ # pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
40+ if execution_result == 0 :
41+ example ["text" ] = [{"role" : "user" , "message" : generate_prompt (example ["prompt" ], example ["test" ])}, {"role" : "assistant" , "message" : format_solution (sample , example ["prompt" ])}]
42+ passed_examples .append (example )
43+ break
44+ raw_datasets = DatasetDict ({"train" : Dataset .from_list (passed_examples ), "validation" : ds ["validation" ]})
45+ raw_datasets ["train" ].to_json (f"{ args .output_dir } /data/verified-samples-iter{ i } .json" )
46+
47+ # train
48+ args .output_dir = f"{ args .output_dir } /models-iter{ i } "
49+ train (raw_datasets , model_name , args )
50+ model_name = args .output_dir
4651
4752
4853if __name__ == "__main__" :
0 commit comments