-
Notifications
You must be signed in to change notification settings - Fork 0
/
analyze-all-reasoning-steps.py
94 lines (79 loc) · 3.42 KB
/
analyze-all-reasoning-steps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import re
import os
from collections import defaultdict
import typer
from rich.console import Console
from rich.table import Table
from rich.align import Align
from constants import SEPARATOR
# Shared Rich console used for all terminal output produced by this script.
console = Console()
# Perturbation names; main() looks for a "cleaned-<name>.txt" file for each one.
PERTURBATIONS = ["irrelevant", "relevant", "pathological", "combo"]
# Patterns for the fields read from each data point. They are compiled once
# because they are applied to every entry of every processed file.
_REASONING_STEPS_RE = re.compile(r"Reasoning Steps:\s*(\d+)")
_CORRECT_ANSWER_RE = re.compile(r">>>> Extracted Correct Answer:\s*(.*?)\n")
_BASELINE_RESPONSE_RE = re.compile(r">>>> Extracted Baseline Response:\s*(.*?)\n")
_EXPERIMENT_RESPONSE_RE = re.compile(
    r">>>> Extracted Experiment Response:\s*(.*?)\n"
)


def _extract_field(pattern: re.Pattern, datapoint: str, label: str) -> str:
    """Return the first value captured by *pattern* in *datapoint*.

    Raises:
        IndexError: If the field is absent. The type matches what the earlier
            ``re.findall(...)[0]`` lookup raised, but the message now names
            the missing field.
    """
    match = pattern.search(datapoint)
    if match is None:
        raise IndexError(f"Data point entry is missing '{label}'")
    return match.group(1)


def analyze_by_reasoning_steps(data: str, model: str, perturbation: str):
    """Print a per-reasoning-step accuracy table for one experiment file.

    The text is split on ``SEPARATOR`` into data points. Each data point is
    bucketed by the integer following ``Reasoning Steps:``; within a bucket
    the baseline and experiment responses are counted as correct when they
    are string-equal to the extracted correct answer. The resulting table is
    printed, centered, on the module-level ``console``.

    Args:
        data: Full contents of an experiment results file.
        model: Model name, used only in the table title.
        perturbation: Perturbation name, used only in the table title.

    Raises:
        Exception: If a non-blank data point has no ``Reasoning Steps:`` line.
        IndexError: If a data point lacks one of the ``>>>> Extracted ...``
            lines.
    """
    grouped_results = defaultdict(
        lambda: {"total": 0, "baseline_correct": 0, "experiment_correct": 0}
    )

    for datapoint in data.split(SEPARATOR):
        # Blank chunks appear around leading, trailing or doubled separators.
        # Skip them instead of stopping, so one stray separator cannot
        # silently drop every entry that follows it.
        if not datapoint.strip():
            continue

        steps_match = _REASONING_STEPS_RE.search(datapoint)
        if steps_match is None:
            raise Exception("Data point entry does not outline reasoning steps")
        reasoning_steps = int(steps_match.group(1))

        correct_answer = _extract_field(
            _CORRECT_ANSWER_RE, datapoint, "Extracted Correct Answer"
        )
        baseline_response = _extract_field(
            _BASELINE_RESPONSE_RE, datapoint, "Extracted Baseline Response"
        )
        experiment_response = _extract_field(
            _EXPERIMENT_RESPONSE_RE, datapoint, "Extracted Experiment Response"
        )

        bucket = grouped_results[reasoning_steps]
        bucket["total"] += 1
        if baseline_response == correct_answer:
            bucket["baseline_correct"] += 1
        if experiment_response == correct_answer:
            bucket["experiment_correct"] += 1

    table = Table(
        title=f"\n\n[bold]Results Breakdown by Reasoning Steps for {model}, {perturbation}[/bold]",
        padding=(0, 2),
    )
    table.add_column("Reasoning Steps", style="cyan bold", justify="center")
    table.add_column("Total Entries", style="magenta bold", justify="center")
    table.add_column("Baseline Accuracy (%)", style="bold", justify="center")
    table.add_column("Experiment Accuracy (%)", style="bold", justify="center")

    # A bucket only exists once an entry has been counted in it, so ``total``
    # is always at least 1 and the divisions below cannot hit zero.
    for steps, results in sorted(grouped_results.items()):
        total = results["total"]
        baseline_accuracy = (results["baseline_correct"] / total) * 100
        experiment_accuracy = (results["experiment_correct"] / total) * 100
        table.add_row(
            str(steps),
            str(total),
            f"{baseline_accuracy:.2f}",
            f"{experiment_accuracy:.2f}",
        )

    console.print(Align.center(table))
def main():
    """Analyze every ``cleaned-<perturbation>.txt`` file under ``data/experiments``.

    Each sub-directory of ``data/experiments`` is treated as a model name.
    For every name in ``PERTURBATIONS`` whose ``cleaned-<name>.txt`` file
    exists in that directory, the file is read and passed to
    ``analyze_by_reasoning_steps``. Missing files are skipped.
    """
    base_path = "data/experiments"

    # os.listdir returns entries in arbitrary order; sort so the report comes
    # out in the same order on every run and machine.
    for model in sorted(os.listdir(base_path)):
        model_path = os.path.join(base_path, model)
        if not os.path.isdir(model_path):
            continue

        for perturbation in PERTURBATIONS:
            file_path = os.path.join(model_path, f"cleaned-{perturbation}.txt")
            # Check for the file directly rather than re-listing the directory
            # for each perturbation.
            if not os.path.isfile(file_path):
                continue

            console.print(
                f"\n[cyan bold]Processing file: {file_path}[/cyan bold]\n"
            )
            with open(file_path, "r") as f:
                data = f.read()
            analyze_by_reasoning_steps(data, model, perturbation)
# Run main() through Typer when the file is executed as a script.
if __name__ == "__main__":
    typer.run(main)