diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..c5ed33e Binary files /dev/null and b/.DS_Store differ diff --git a/README.md b/README.md index 6062822..0e38e90 100755 --- a/README.md +++ b/README.md @@ -30,8 +30,20 @@ https://github.com/kojima-takeshi188/zero_shot_cot/tree/main/log ## Quick Start +See ```try_cot.ipynb``` + +## Instructions + +Construct Demos: + +``` +python run_demo.py --task multiarith --pred_file log/multiarith_zero_shot_cot.log --demo_save_dir demos/multiarith +``` + +Run inference: + ``` -python run_demo_log.py +python run_inference.py --dataset multiarith --demo_path demos/multiarith --output_dir experiment/multiarith ``` ## Citing Auto-CoT diff --git a/api.py b/api.py new file mode 100644 index 0000000..4990cae --- /dev/null +++ b/api.py @@ -0,0 +1,93 @@ +import argparse +from utils import * + +def cot(method, question): + args = parse_arguments() + decoder = Decoder() + + args.method = method + if args.method != "zero_shot_cot": + if args.method == "auto_cot": + args.demo_path = "demos/multiarith_auto" + else: + args.demo_path = "demos/multiarith_manual" + demo = create_demo_text(args, cot_flag=True) + else: + demo = None + + x = "Q: " + question + "\n" + "A:" + print('*****************************') + print("Test Question:") + print(question) + print('*****************************') + + if args.method == "zero_shot": + x = x + " " + args.direct_answer_trigger_for_zeroshot + elif args.method == "zero_shot_cot": + x = x + " " + args.cot_trigger + elif args.method == "manual_cot": + x = demo + x + elif args.method == "auto_cot": + x = demo + x + " " + args.cot_trigger + else: + raise ValueError("method is not properly defined ...") + + print("Prompted Input:") + print(x.replace("\n\n", "\n").strip()) + print('*****************************') + + max_length = args.max_length_cot if "cot" in args.method else args.max_length_direct + z = decoder.decode(args, x, max_length) + z = z.replace("\n\n", "\n").replace("\n", "").strip() + if args.method == "zero_shot_cot": + z2 = x + z + " " + args.direct_answer_trigger_for_zeroshot_cot + max_length = args.max_length_direct + pred = decoder.decode(args, z2, max_length) + print("Output:") + print(z + " " + args.direct_answer_trigger_for_zeroshot_cot + " " + pred) + print('*****************************') + else: + pred = z + print("Output:") + print(pred) + print('*****************************') + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Zero-shot-CoT") + + parser.add_argument("--max_num_worker", type=int, default=0, help="maximum number of workers for dataloader") + parser.add_argument( + "--model", type=str, default="gpt3-xl", help="model used for decoding. Note that 'gpt3' are the smallest models." + ) + parser.add_argument( + "--method", type=str, default="auto_cot", choices=["zero_shot", "zero_shot_cot", "few_shot", "few_shot_cot", "auto_cot"], help="method" + ) + parser.add_argument( + "--cot_trigger_no", type=int, default=1, help="A trigger sentence that elicits a model to execute chain of thought" + ) + parser.add_argument( + "--max_length_cot", type=int, default=256, help="maximum length of output tokens by model for reasoning extraction" + ) + parser.add_argument( + "--max_length_direct", type=int, default=32, help="maximum length of output tokens by model for answer extraction" + ) + parser.add_argument( + "--limit_dataset_size", type=int, default=0, help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing." + ) + parser.add_argument( + "--api_time_interval", type=float, default=1.0, help="" + ) + parser.add_argument( + "--temperature", type=float, default=0, help="" + ) + parser.add_argument( + "--log_dir", type=str, default="./log/", help="log directory" + ) + args = parser.parse_args() + + args.direct_answer_trigger_for_fewshot = "The answer is" + args.direct_answer_trigger_for_zeroshot = "The answer is" + args.direct_answer_trigger_for_zeroshot_cot = "The answer is" + args.cot_trigger = "Let's think step by step." + + return args \ No newline at end of file diff --git a/demos/multiarith_auto b/demos/multiarith_auto new file mode 100644 index 0000000..3367092 --- /dev/null +++ b/demos/multiarith_auto @@ -0,0 +1,52 @@ +{ + "demo": [ + { + "question": "Q: Wendy uploaded 45 pictures to Facebook. She put 27 pics into one album and put the rest into 9 different albums. How many pictures were in each album?\nA:", + "rationale": "Let's think step by step. First, we know that Wendy uploaded 45 pictures in total. Second, we know that Wendy put 27 pictures into one album. That means that Wendy put the remaining 18 pictures into 9 different albums. That means that each album would have 2 pictures.", + "pred_ans": "2", + "gold_ans": "2" + }, + { + "question": "Q: For Halloween Katie and her sister combined the candy they received. Katie had 8 pieces of candy while her sister had 23. If they ate 8 pieces the first night, how many pieces do they have left?\nA:", + "rationale": "Let's think step by step. Katie and her sister have a total of 8 + 23 = 31 pieces of candy. If they eat 8 pieces the first night, they have 31 - 8 = 23 pieces left.", + "pred_ans": "23", + "gold_ans": "23" + }, + { + "question": "Q: Bianca was organizing her book case making sure each of the shelves had exactly 8 books on it. If she had 5 shelves of mystery books and 4 shelves of picture books, how many books did she have total?\nA:", + "rationale": "Let's think step by step. There are 5 shelves of mystery books. Each shelf has 8 books. So that's 40 mystery books. There are 4 shelves of picture books. Each shelf has 8 books. So that's 32 picture books. 40 + 32 = 72 books.", + "pred_ans": "72", + "gold_ans": "72" + }, + { + "question": "Q: A pet store had 78 puppies. In one day they sold 30 of them and put the rest into cages with 8 in each cage. How many cages did they use?\nA:", + "rationale": "Let's think step by step. There are 78 puppies. 30 are sold, so that means there are 48 left. 48 divided by 8 is 6, so that means there are 6 cages with 8 puppies in each.", + "pred_ans": "6", + "gold_ans": "6" + }, + { + "question": "Q: A waiter had 14 customers to wait on. If 3 customers left and he got another 39 customers, how many customers would he have?\nA:", + "rationale": "Let's think step by step. The waiter had 14 customers to wait on. If 3 customers left, that means he would have 11 customers left. If he got another 39 customers, that means he would have 50 customers in total.", + "pred_ans": "50", + "gold_ans": "50" + }, + { + "question": "Q: A trivia team had 7 members total, but during a game 2 members didn't show up. If each member that did show up scored 4 points, how many points were scored total?\nA:", + "rationale": "Let's think step by step. There were 7 members on the team, but 2 members didn't show up. That means that there were 5 members that did show up. Each member that showed up scored 4 points. So if 5 members each scored 4 points, then the total number of points scored would be 5*4=20.", + "pred_ans": "20", + "gold_ans": "20" + }, + { + "question": "Q: Gwen had 18 math problems and 11 science problems for homework. If she finished 24 of the problems at school, how many problems did she have to do for homework?\nA:", + "rationale": "Let's think step by step. Gwen had 18 math problems and 11 science problems for homework. That means she had a total of 29 problems for homework. If she finished 24 of the problems at school, that means she had 5 problems left to do for homework.", + "pred_ans": "5", + "gold_ans": "5" + }, + { + "question": "Q: Mike made 69 dollars mowing lawns over the summer. If he spent 24 dollars buying new mower blades, how many 5 dollar games could he buy with the money he had left?\nA:", + "rationale": "Let's think step by step. Mike made $69 from mowing lawns. He spent $24 on new mower blades. That means he has $45 left. Each game costs $5, so he could buy 9 games.", + "pred_ans": "9", + "gold_ans": "9" + } + ] +} \ No newline at end of file diff --git a/demos/multiarith_manual b/demos/multiarith_manual new file mode 100644 index 0000000..e79f8f3 --- /dev/null +++ b/demos/multiarith_manual @@ -0,0 +1,44 @@ +{ + "demo": [ + { + "question": "Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA:", + "rationale": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.", + "pred_ans": "6" + }, + { + "question": "Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA:", + "rationale": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.", + "pred_ans": "5" + }, + { + "question": "Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nA:", + "rationale": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.", + "pred_ans": "39" + }, + { + "question": "Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nA:", + "rationale": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.", + "pred_ans": "8" + }, + { + "question": "Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nA:", + "rationale": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.", + "pred_ans": "9" + }, + { + "question": "Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nA:", + "rationale": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.", + "pred_ans": "29" + }, + { + "question": "Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nA:", + "rationale": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.", + "pred_ans": "33" + }, + { + "question": "Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nA:", + "rationale": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.", + "pred_ans": "8" + } + ] +} \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 459061c..2b14069 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ sklearn matplotlib sentence-transformers +jupyter diff --git a/run_demo.py b/run_demo.py index a9021a2..ee773fd 100755 --- a/run_demo.py +++ b/run_demo.py @@ -23,7 +23,7 @@ def parse_arguments(): help="use the reasoning chains generated by zero-shot-cot." ) parser.add_argument( - "--demo_save_dir", type=str, default="demo/multiarith", help="where to save the contructed demonstrations" + "--demo_save_dir", type=str, default="demos/multiarith", help="where to save the contructed demonstrations" ) parser.add_argument("--random_seed", type=int, default=192, help="random seed") parser.add_argument( @@ -110,7 +110,6 @@ def main(): clustered_idx[cluster_id].append(sentence_id) demos = [] - curr_wrong = 0 for i in range(len(clustered_dists)): print("Cluster ", i+1) @@ -153,7 +152,6 @@ def main(): with open(args.demo_save_dir, 'w', encoding="utf-8") as write_f: json.dump(demos, write_f, indent=4, ensure_ascii=False) - print(curr_wrong) y_km = clustering_model.fit_predict(corpus_embeddings) pca_model = PCA(n_components=2, random_state=args.random_seed) diff --git a/run_inference.py b/run_inference.py new file mode 100644 index 0000000..2d4f74c --- /dev/null +++ b/run_inference.py @@ -0,0 +1,211 @@ +''' +Adapted from https://github.com/kojima-takeshi188/zero_shot_cot +''' + +import argparse +from utils import * + +def main(): + args = parse_arguments() + print('*****************************') + print(args) + print('*****************************') + + fix_seed(args.random_seed) + + print("OPENAI_API_KEY:") + print(os.getenv("OPENAI_API_KEY")) + + # Initialize decoder class (load model and tokenizer) ... + decoder = Decoder() + + print("setup data loader ...") + dataloader = setup_data_loader(args) + print_now() + + if args.method == "few_shot": + demo = create_demo_text(args, cot_flag=False) + elif args.method == "few_shot_cot" or args.method == "auto_cot": + demo = create_demo_text(args, cot_flag=True) + else: + pass + + total = 0 + correct_list = [] + with open(args.output_dir, "a") as wp: + + for i, data in enumerate(dataloader): + if i < args.resume_id - 1: + # if i < 297: + continue + output_line = {} + + print('*************************') + print("{}st data".format(i+1)) + + # Prepare question template ... + x, y = data + x = "Q: " + x[0] + "\n" + "A:" + y = y[0].strip() + + # print(x, y) + + output_line["question"] = x + output_line["gold_ans"] = y + + if args.method == "zero_shot": + x = x + " " + args.direct_answer_trigger_for_zeroshot + elif args.method == "zero_shot_cot": + x = x + " " + args.cot_trigger + elif args.method == "few_shot": + x = demo + x + elif args.method == "few_shot_cot": + x = demo + x + elif args.method == "auto_cot": + x = demo + x + " " + args.cot_trigger + else: + raise ValueError("method is not properly defined ...") + + # Answer experiment by generating text ... + max_length = args.max_length_cot if "cot" in args.method else args.max_length_direct + z = decoder.decode(args, x, max_length) + + output_line["rationale"] = z + + # Answer extraction for zero-shot-cot ... + if args.method == "zero_shot_cot": + z2 = x + z + " " + args.direct_answer_trigger_for_zeroshot_cot + max_length = args.max_length_direct + pred = decoder.decode(args, z2, max_length) + print(z2 + pred) + else: + pred = z + print(x + pred) + + # Clensing of predicted answer ... + pred = answer_cleansing(args, pred) + + + output_line["pred_ans"] = pred + output_line["wrap_que"] = x + + output_json = json.dumps(output_line) + wp.write(output_json + '\n') + + # Choose the most frequent answer from the list ... + print("pred : {}".format(pred)) + print("GT : " + y) + print('*************************') + + # Checking answer ... + correct = (np.array([pred]) == np.array([y])).sum().item() + correct_list.append(correct) + total += 1 #np.array([y]).size(0) + + if (args.limit_dataset_size != 0) and ((i+1) >= args.limit_dataset_size): + break + #raise ValueError("Stop !!") + + # Calculate accuracy ... + accuracy = (sum(correct_list) * 1.0 / total) * 100 + print("accuracy : {}".format(accuracy)) + +def parse_arguments(): + parser = argparse.ArgumentParser(description="Zero-shot-CoT") + + parser.add_argument("--random_seed", type=int, default=1, help="random seed") + parser.add_argument( + "--dataset", type=str, default="multiarith", choices=["aqua", "gsm8k", "commonsensqa", "addsub", "multiarith", "strategyqa", "svamp", "singleeq", "coin_flip", "last_letters"], help="dataset used for experiment" + ) + parser.add_argument( + "--demo_path", type=str, default="demos/multiarith", help="dataset used for experiment" + ) + parser.add_argument( + "--resume_id", type=int, default=0, help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing." + ) + parser.add_argument("--minibatch_size", type=int, default=1, choices=[1], help="minibatch size should be 1 because GPT-3 API takes only 1 input for each request") + + parser.add_argument("--max_num_worker", type=int, default=0, help="maximum number of workers for dataloader") + + parser.add_argument( + "--model", type=str, default="gpt3-xl", choices=["gpt3", "gpt3-medium", "gpt3-large", "gpt3-xl", "code-davinci-002"], help="model used for decoding. Note that 'gpt3' are the smallest models." + ) + + parser.add_argument( + "--method", type=str, default="auto_cot", choices=["zero_shot", "zero_shot_cot", "few_shot", "few_shot_cot", "auto_cot"], help="method" + ) + parser.add_argument( + "--output_dir", type=str, default="experiment/multiarith", help="output directory" + ) + parser.add_argument( + "--max_length_cot", type=int, default=256, help="maximum length of output tokens by model for reasoning extraction" + ) + parser.add_argument( + "--max_length_direct", type=int, default=32, help="maximum length of output tokens by model for answer extraction" + ) + parser.add_argument( + "--limit_dataset_size", type=int, default=0, help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing." + ) + parser.add_argument( + "--api_time_interval", type=float, default=1.0, help="" + ) + parser.add_argument( + "--temperature", type=float, default=0, help="" + ) + parser.add_argument( + "--log_dir", type=str, default="./log/", help="log directory" + ) + + args = parser.parse_args() + + if args.dataset == "aqua": + args.dataset_path = "./dataset/AQuA/test.json" + args.direct_answer_trigger = "\nTherefore, among A through E, the answer is" + elif args.dataset == "gsm8k": + args.dataset_path = "./dataset/grade-school-math/test.jsonl" + args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is" + elif args.dataset == "commonsensqa": + args.dataset_path = "./dataset/CommonsenseQA/dev_rand_split.jsonl" + args.direct_answer_trigger = "\nTherefore, among A through E, the answer is" + args.plausible_answer_trigger = "Choose the most plausible answer from among choices A through E." + elif args.dataset == "addsub": + args.dataset_path = "./dataset/AddSub/AddSub.json" + args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is" + elif args.dataset == "multiarith": + args.dataset_path = "./dataset/MultiArith/MultiArith.json" + args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is" + elif args.dataset == "strategyqa": + args.dataset_path = "./dataset/StrategyQA/task.json" + args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is" + elif args.dataset == "svamp": + args.dataset_path = "./dataset/SVAMP/SVAMP.json" + args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is" + elif args.dataset == "singleeq": + args.dataset_path = "./dataset/SingleEq/questions.json" + args.direct_answer_trigger = "\nTherefore, the answer (arabic numerals) is" + elif args.dataset == "bigbench_date": + args.dataset_path = "./dataset/Bigbench_Date/task.json" + args.direct_answer_trigger = "\nTherefore, among A through F, the answer is" + elif args.dataset == "object_tracking": + args.dataset_path = "./dataset/Bigbench_object_tracking/task.json" + args.direct_answer_trigger = "\nTherefore, among A through C, the answer is" + elif args.dataset == "coin_flip": + args.dataset_path = "./dataset/coin_flip/coin_flip.json" + args.direct_answer_trigger = "\nTherefore, the answer (Yes or No) is" + elif args.dataset == "last_letters": + args.dataset_path = "./dataset/last_letters/last_letters.json" + args.direct_answer_trigger = "\nTherefore, the answer is" + else: + raise ValueError("dataset is not properly defined ...") + + # "Therefore, the answer ..." -> "The answer ..." + trigger = args.direct_answer_trigger.replace("\nTherefore, ", "") + args.direct_answer_trigger_for_zeroshot = trigger[0].upper() + trigger[1:] + args.direct_answer_trigger_for_zeroshot_cot = args.direct_answer_trigger + args.direct_answer_trigger_for_fewshot = "The answer is" + args.cot_trigger = "Let's think step by step." + + return args + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/try_cot.ipynb b/try_cot.ipynb new file mode 100644 index 0000000..b6d147c --- /dev/null +++ b/try_cot.ipynb @@ -0,0 +1,215 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Example: Zero-Shot\n", + "*****************************\n", + "Test Question:\n", + "There were 10 friends playing a video game online when 7 players quit. If each player left had 8 lives, how many lives did they have total?\n", + "*****************************\n", + "Prompted Input:\n", + "Q: There were 10 friends playing a video game online when 7 players quit. If each player left had 8 lives, how many lives did they have total?\n", + "A: The answer is\n", + "*****************************\n", + "Output:\n", + "80.\n", + "*****************************\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.argv=['']\n", + "del sys\n", + "from api import cot\n", + "question = \"There were 10 friends playing a video game online when 7 players quit. If each player left had 8 lives, how many lives did they have total?\"\n", + "# for involving GPT-3, please add your openai-api key in utils.py (#Line 59)\n", + "# method = [\"zero_shot\", \"zero_shot_cot\", \"manual_cot\", \"auto_cot\"]\n", + "print(\"Example: Zero-Shot\")\n", + "cot(method=\"zero_shot\", question=question)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Example: Zero-Shot-CoT\n", + "2022/10/12 12:48:47\n", + "*****************************\n", + "Test Question:\n", + "There were 10 friends playing a video game online when 7 players quit. If each player left had 8 lives, how many lives did they have total?\n", + "*****************************\n", + "Prompted Input:\n", + "Q: There were 10 friends playing a video game online when 7 players quit. If each player left had 8 lives, how many lives did they have total?\n", + "A: Let's think step by step.\n", + "*****************************\n", + "Output:\n", + "There were 10 friends playing a video game online. This means that, at the start, there were 10 x 8 = 80 lives in total. Then, 7 players quit. This means that 7 x 8 = 56 lives were lost. Therefore, the total number of lives remaining is 80 - 56 = 24. The answer is 24.\n", + "*****************************\n" + ] + } + ], + "source": [ + "from api import cot\n", + "question = \"There were 10 friends playing a video game online when 7 players quit. If each player left had 8 lives, how many lives did they have total?\"\n", + "# for involving GPT-3, please add your openai-api key in utils.py (#Line 55)\n", + "# method = [\"zero_shot\", \"zero_shot_cot\", \"manual_cot\", \"auto_cot\"]\n", + "print(\"Example: Zero-Shot-CoT\")\n", + "cot(method=\"zero_shot_cot\", question=question)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Example: Manual-CoT\n", + "2022/10/12 12:39:57\n", + "*****************************\n", + "Test Question:\n", + "In a video game, each enemy defeated gives you 7 points. If a level has 11 enemies total and you destroy all but 8 of them, how many points would you earn?\n", + "*****************************\n", + "Prompted Input:\n", + "Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\n", + "A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n", + "Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\n", + "A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n", + "Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\n", + "A: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.\n", + "Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\n", + "A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.\n", + "Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\n", + "A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.\n", + "Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\n", + "A: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.\n", + "Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\n", + "A: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.\n", + "Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\n", + "A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.\n", + "Q: In a video game, each enemy defeated gives you 7 points. If a level has 11 enemies total and you destroy all but 8 of them, how many points would you earn?\n", + "A:\n", + "*****************************\n", + "Output:\n", + "You would earn 7 points for each enemy you defeat. So you would earn 7 * 11 = 77 points.\n", + "*****************************\n" + ] + } + ], + "source": [ + "from api import cot\n", + "question = \"In a video game, each enemy defeated gives you 7 points. If a level has 11 enemies total and you destroy all but 8 of them, how many points would you earn?\"\n", + "# for involving GPT-3, please add your openai-api key in utils.py (#Line 55)\n", + "# method = [\"zero_shot\", \"zero_shot_cot\", \"manual_cot\", \"auto_cot\"]\n", + "print(\"Example: Manual-CoT\")\n", + "cot(method=\"manual_cot\", question=question)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Example: Auto-CoT\n", + "2022/10/12 12:40:22\n", + "*****************************\n", + "Test Question:\n", + "In a video game, each enemy defeated gives you 7 points. If a level has 11 enemies total and you destroy all but 8 of them, how many points would you earn?\n", + "*****************************\n", + "Prompted Input:\n", + "Q: Wendy uploaded 45 pictures to Facebook. She put 27 pics into one album and put the rest into 9 different albums. How many pictures were in each album?\n", + "A: Let's think step by step. First, we know that Wendy uploaded 45 pictures in total. Second, we know that Wendy put 27 pictures into one album. That means that Wendy put the remaining 18 pictures into 9 different albums. That means that each album would have 2 pictures. The answer is 2.\n", + "Q: For Halloween Katie and her sister combined the candy they received. Katie had 8 pieces of candy while her sister had 23. If they ate 8 pieces the first night, how many pieces do they have left?\n", + "A: Let's think step by step. Katie and her sister have a total of 8 + 23 = 31 pieces of candy. If they eat 8 pieces the first night, they have 31 - 8 = 23 pieces left. The answer is 23.\n", + "Q: Bianca was organizing her book case making sure each of the shelves had exactly 8 books on it. If she had 5 shelves of mystery books and 4 shelves of picture books, how many books did she have total?\n", + "A: Let's think step by step. There are 5 shelves of mystery books. Each shelf has 8 books. So that's 40 mystery books. There are 4 shelves of picture books. Each shelf has 8 books. So that's 32 picture books. 40 + 32 = 72 books. The answer is 72.\n", + "Q: A pet store had 78 puppies. In one day they sold 30 of them and put the rest into cages with 8 in each cage. How many cages did they use?\n", + "A: Let's think step by step. There are 78 puppies. 30 are sold, so that means there are 48 left. 48 divided by 8 is 6, so that means there are 6 cages with 8 puppies in each. The answer is 6.\n", + "Q: A waiter had 14 customers to wait on. If 3 customers left and he got another 39 customers, how many customers would he have?\n", + "A: Let's think step by step. The waiter had 14 customers to wait on. If 3 customers left, that means he would have 11 customers left. If he got another 39 customers, that means he would have 50 customers in total. The answer is 50.\n", + "Q: A trivia team had 7 members total, but during a game 2 members didn't show up. If each member that did show up scored 4 points, how many points were scored total?\n", + "A: Let's think step by step. There were 7 members on the team, but 2 members didn't show up. That means that there were 5 members that did show up. Each member that showed up scored 4 points. So if 5 members each scored 4 points, then the total number of points scored would be 5*4=20. The answer is 20.\n", + "Q: Gwen had 18 math problems and 11 science problems for homework. If she finished 24 of the problems at school, how many problems did she have to do for homework?\n", + "A: Let's think step by step. Gwen had 18 math problems and 11 science problems for homework. That means she had a total of 29 problems for homework. If she finished 24 of the problems at school, that means she had 5 problems left to do for homework. The answer is 5.\n", + "Q: Mike made 69 dollars mowing lawns over the summer. If he spent 24 dollars buying new mower blades, how many 5 dollar games could he buy with the money he had left?\n", + "A: Let's think step by step. Mike made $69 from mowing lawns. He spent $24 on new mower blades. That means he has $45 left. Each game costs $5, so he could buy 9 games. The answer is 9.\n", + "Q: In a video game, each enemy defeated gives you 7 points. If a level has 11 enemies total and you destroy all but 8 of them, how many points would you earn?\n", + "A: Let's think step by step.\n", + "*****************************\n", + "Output:\n", + "Each enemy defeated gives you 7 points. If there are 11 enemies and you destroy all but 8 of them, that means you would have destroyed 3 enemies. 3 enemies * 7 points = 21 points. The answer is 21.\n", + "*****************************\n" + ] + } + ], + "source": [ + "from api import cot\n", + "question = \"In a video game, each enemy defeated gives you 7 points. If a level has 11 enemies total and you destroy all but 8 of them, how many points would you earn?\"\n", + "# for involving GPT-3, please add your openai-api key in utils.py (#Line 55)\n", + "# method = [\"zero_shot\", \"zero_shot_cot\", \"manual_cot\", \"auto_cot\"]\n", + "print(\"Example: Auto-CoT\")\n", + "cot(method=\"auto_cot\", question=question)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/utils.py b/utils.py old mode 100755 new mode 100644 index c86aa09..c1f0abe --- a/utils.py +++ b/utils.py @@ -1,11 +1,18 @@ +''' +Adapted from https://github.com/kojima-takeshi188/zero_shot_cot +''' + from statistics import mean from torch.utils.data import Dataset +import openai +import os import multiprocessing import json import numpy as np import torch import re import random +import time import datetime def shuffleDict(d): @@ -41,6 +48,67 @@ def print_now(return_flag=0): else: pass +# Sentence Generator (Decoder) for GPT-3 ... +def decoder_for_gpt3(args, input, max_length): + + # GPT-3 API allows each users execute the API within 60 times in a minute ... + # time.sleep(1) + time.sleep(args.api_time_interval) + + # https://beta.openai.com/account/api-keys + openai.api_key = "[Your OpenAI API Key]" + + # Specify engine ... + # Instruct GPT3 + if args.model == "gpt3": + engine = "text-ada-001" + elif args.model == "gpt3-medium": + engine = "text-babbage-001" + elif args.model == "gpt3-large": + engine = "text-curie-001" + elif args.model == "gpt3-xl": + engine = "text-davinci-002" + elif args.model == "text-davinci-001": + engine = "text-davinci-001" + elif args.model == "code-davinci-002": + engine = "code-davinci-002" + else: + raise ValueError("model is not properly defined ...") + + if ("few_shot" in args.method or "auto" in args.method) and engine == "code-davinci-002": + response = openai.Completion.create( + engine=engine, + prompt=input, + max_tokens=max_length, + temperature=args.temperature, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + stop=["\n"] + ) + else: + response = openai.Completion.create( + engine=engine, + prompt=input, + max_tokens=max_length, + temperature=args.temperature, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + stop=None + ) + + return response["choices"][0]["text"] + +class Decoder(): + def __init__(self): + # print_now() + pass + + def decode(self, args, input, max_length): + response = decoder_for_gpt3(args, input, max_length) + return response + def data_reader(args): questions = [] @@ -218,6 +286,7 @@ def seed_worker(worker_id): return dataloader +# ver 0.2 def answer_cleansing(args, pred, must_choice=False): print("pred_before : " + pred) @@ -296,4 +365,43 @@ def create_demo_text(args, cot_flag): args.direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n" else: demo_text += x[i] + " " + args.direct_answer_trigger_for_fewshot + " " + y[i] + ".\n\n" - return demo_text \ No newline at end of file + return demo_text + +def answer_cleansing_zero_shot(args, pred, must_choice=False): + pred = pred.strip() + if args.dataset in ("aqua", "commonsensqa"): + pred = re.findall(r'A|B|C|D|E', pred) + elif args.dataset == "bigbench_date": + pred = re.findall(r'A|B|C|D|E|F', pred) + elif args.dataset in ("object_tracking"): + pred = re.findall(r'A|B|C', pred) + elif args.dataset in ("gsm8k", "addsub", "multiarith", "svamp", "singleeq"): + if must_choice: + pred = re.findall(r'A|B|C|D', pred) + else: + pred = pred.replace(",", "") + pred = [s for s in re.findall(r'-?\d+\.?\d*', pred)] + elif args.dataset in ("strategyqa", "coin_flip"): + pred = pred.lower() + pred = re.sub("\"|\'|\n|\.|\s|\:|\,", " ", pred) + pred = pred.split(" ") + pred = [i for i in pred if i in ("yes", "no")] + elif args.dataset == "last_letters": + pred = re.sub("\"|\'|\n|\.|\s", "", pred) + pred = [pred] + else: + raise ValueError("dataset is not properly defined ...") + + # If there is no candidate in list, null is set. + if len(pred) == 0: + pred = "" + else: + # choose the first element in list ... + pred = pred[0] + + # (For arithmetic tasks) if a word ends with period, it will be omitted ... + if pred != "": + if pred[-1] == ".": + pred = pred[:-1] + + return pred