Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update inference #1

Merged
merged 6 commits into from
Oct 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added .DS_Store
Binary file not shown.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,20 @@ https://github.com/kojima-takeshi188/zero_shot_cot/tree/main/log

## Quick Start

See ```try_cot.ipynb```

## Instructions

Construct Demos:

```
python run_demo.py --task multiarith --pred_file log/multiarith_zero_shot_cot.log --demo_save_dir demos/multiarith
```

Run inference:

```
python run_demo_log.py
python run_inference.py --dataset multiarith --demo_path demos/multiarith --output_dir experiment/multiarith
```

## Citing Auto-CoT
Expand Down
93 changes: 93 additions & 0 deletions api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import argparse
from utils import *

def cot(method, question):
args = parse_arguments()
decoder = Decoder()

args.method = method
if args.method != "zero_shot_cot":
if args.method == "auto_cot":
args.demo_path = "demos/multiarith_auto"
else:
args.demo_path = "demos/multiarith_manual"
demo = create_demo_text(args, cot_flag=True)
else:
demo = None

x = "Q: " + question + "\n" + "A:"
print('*****************************')
print("Test Question:")
print(question)
print('*****************************')

if args.method == "zero_shot":
x = x + " " + args.direct_answer_trigger_for_zeroshot
elif args.method == "zero_shot_cot":
x = x + " " + args.cot_trigger
elif args.method == "manual_cot":
x = demo + x
elif args.method == "auto_cot":
x = demo + x + " " + args.cot_trigger
else:
raise ValueError("method is not properly defined ...")

print("Prompted Input:")
print(x.replace("\n\n", "\n").strip())
print('*****************************')

max_length = args.max_length_cot if "cot" in args.method else args.max_length_direct
z = decoder.decode(args, x, max_length)
z = z.replace("\n\n", "\n").replace("\n", "").strip()
if args.method == "zero_shot_cot":
z2 = x + z + " " + args.direct_answer_trigger_for_zeroshot_cot
max_length = args.max_length_direct
pred = decoder.decode(args, z2, max_length)
print("Output:")
print(z + " " + args.direct_answer_trigger_for_zeroshot_cot + " " + pred)
print('*****************************')
else:
pred = z
print("Output:")
print(pred)
print('*****************************')

def parse_arguments():
parser = argparse.ArgumentParser(description="Zero-shot-CoT")

parser.add_argument("--max_num_worker", type=int, default=0, help="maximum number of workers for dataloader")
parser.add_argument(
"--model", type=str, default="gpt3-xl", help="model used for decoding. Note that 'gpt3' are the smallest models."
)
parser.add_argument(
"--method", type=str, default="auto_cot", choices=["zero_shot", "zero_shot_cot", "few_shot", "few_shot_cot", "auto_cot"], help="method"
)
parser.add_argument(
"--cot_trigger_no", type=int, default=1, help="A trigger sentence that elicits a model to execute chain of thought"
)
parser.add_argument(
"--max_length_cot", type=int, default=256, help="maximum length of output tokens by model for reasoning extraction"
)
parser.add_argument(
"--max_length_direct", type=int, default=32, help="maximum length of output tokens by model for answer extraction"
)
parser.add_argument(
"--limit_dataset_size", type=int, default=0, help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing."
)
parser.add_argument(
"--api_time_interval", type=float, default=1.0, help=""
)
parser.add_argument(
"--temperature", type=float, default=0, help=""
)
parser.add_argument(
"--log_dir", type=str, default="./log/", help="log directory"
)
args = parser.parse_args()

args.direct_answer_trigger_for_fewshot = "The answer is"
args.direct_answer_trigger_for_zeroshot = "The answer is"
args.direct_answer_trigger_for_zeroshot_cot = "The answer is"
args.cot_trigger = "Let's think step by step."

return args
52 changes: 52 additions & 0 deletions demos/multiarith_auto
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
{
"demo": [
{
"question": "Q: Wendy uploaded 45 pictures to Facebook. She put 27 pics into one album and put the rest into 9 different albums. How many pictures were in each album?\nA:",
"rationale": "Let's think step by step. First, we know that Wendy uploaded 45 pictures in total. Second, we know that Wendy put 27 pictures into one album. That means that Wendy put the remaining 18 pictures into 9 different albums. That means that each album would have 2 pictures.",
"pred_ans": "2",
"gold_ans": "2"
},
{
"question": "Q: For Halloween Katie and her sister combined the candy they received. Katie had 8 pieces of candy while her sister had 23. If they ate 8 pieces the first night, how many pieces do they have left?\nA:",
"rationale": "Let's think step by step. Katie and her sister have a total of 8 + 23 = 31 pieces of candy. If they eat 8 pieces the first night, they have 31 - 8 = 23 pieces left.",
"pred_ans": "23",
"gold_ans": "23"
},
{
"question": "Q: Bianca was organizing her book case making sure each of the shelves had exactly 8 books on it. If she had 5 shelves of mystery books and 4 shelves of picture books, how many books did she have total?\nA:",
"rationale": "Let's think step by step. There are 5 shelves of mystery books. Each shelf has 8 books. So that's 40 mystery books. There are 4 shelves of picture books. Each shelf has 8 books. So that's 32 picture books. 40 + 32 = 72 books.",
"pred_ans": "72",
"gold_ans": "72"
},
{
"question": "Q: A pet store had 78 puppies. In one day they sold 30 of them and put the rest into cages with 8 in each cage. How many cages did they use?\nA:",
"rationale": "Let's think step by step. There are 78 puppies. 30 are sold, so that means there are 48 left. 48 divided by 8 is 6, so that means there are 6 cages with 8 puppies in each.",
"pred_ans": "6",
"gold_ans": "6"
},
{
"question": "Q: A waiter had 14 customers to wait on. If 3 customers left and he got another 39 customers, how many customers would he have?\nA:",
"rationale": "Let's think step by step. The waiter had 14 customers to wait on. If 3 customers left, that means he would have 11 customers left. If he got another 39 customers, that means he would have 50 customers in total.",
"pred_ans": "50",
"gold_ans": "50"
},
{
"question": "Q: A trivia team had 7 members total, but during a game 2 members didn't show up. If each member that did show up scored 4 points, how many points were scored total?\nA:",
"rationale": "Let's think step by step. There were 7 members on the team, but 2 members didn't show up. That means that there were 5 members that did show up. Each member that showed up scored 4 points. So if 5 members each scored 4 points, then the total number of points scored would be 5*4=20.",
"pred_ans": "20",
"gold_ans": "20"
},
{
"question": "Q: Gwen had 18 math problems and 11 science problems for homework. If she finished 24 of the problems at school, how many problems did she have to do for homework?\nA:",
"rationale": "Let's think step by step. Gwen had 18 math problems and 11 science problems for homework. That means she had a total of 29 problems for homework. If she finished 24 of the problems at school, that means she had 5 problems left to do for homework.",
"pred_ans": "5",
"gold_ans": "5"
},
{
"question": "Q: Mike made 69 dollars mowing lawns over the summer. If he spent 24 dollars buying new mower blades, how many 5 dollar games could he buy with the money he had left?\nA:",
"rationale": "Let's think step by step. Mike made $69 from mowing lawns. He spent $24 on new mower blades. That means he has $45 left. Each game costs $5, so he could buy 9 games.",
"pred_ans": "9",
"gold_ans": "9"
}
]
}
44 changes: 44 additions & 0 deletions demos/multiarith_manual
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
{
"demo": [
{
"question": "Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA:",
"rationale": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6.",
"pred_ans": "6"
},
{
"question": "Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA:",
"rationale": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
"pred_ans": "5"
},
{
"question": "Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nA:",
"rationale": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39.",
"pred_ans": "39"
},
{
"question": "Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nA:",
"rationale": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8.",
"pred_ans": "8"
},
{
"question": "Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?\nA:",
"rationale": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9.",
"pred_ans": "9"
},
{
"question": "Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?\nA:",
"rationale": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29.",
"pred_ans": "29"
},
{
"question": "Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?\nA:",
"rationale": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls.",
"pred_ans": "33"
},
{
"question": "Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?\nA:",
"rationale": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8.",
"pred_ans": "8"
}
]
}
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
sklearn
matplotlib
sentence-transformers
jupyter
4 changes: 1 addition & 3 deletions run_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def parse_arguments():
help="use the reasoning chains generated by zero-shot-cot."
)
parser.add_argument(
"--demo_save_dir", type=str, default="demo/multiarith", help="where to save the contructed demonstrations"
"--demo_save_dir", type=str, default="demos/multiarith", help="where to save the contructed demonstrations"
)
parser.add_argument("--random_seed", type=int, default=192, help="random seed")
parser.add_argument(
Expand Down Expand Up @@ -110,7 +110,6 @@ def main():
clustered_idx[cluster_id].append(sentence_id)

demos = []
curr_wrong = 0

for i in range(len(clustered_dists)):
print("Cluster ", i+1)
Expand Down Expand Up @@ -153,7 +152,6 @@ def main():

with open(args.demo_save_dir, 'w', encoding="utf-8") as write_f:
json.dump(demos, write_f, indent=4, ensure_ascii=False)
print(curr_wrong)

y_km = clustering_model.fit_predict(corpus_embeddings)
pca_model = PCA(n_components=2, random_state=args.random_seed)
Expand Down
Loading