-
Notifications
You must be signed in to change notification settings - Fork 130
/
rstar.py
334 lines (291 loc) · 14.6 KB
/
rstar.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
import asyncio
import logging
import math
import random
import re
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple

import aiohttp
# Module-wide logging configuration: INFO level with timestamped records.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
class Node:
    """A node in the reasoning search tree.

    ``state`` is the text held by the node (the question at the root, a
    model response elsewhere).  ``action`` is the prompt action that produced
    it, or ``None`` for the root.  ``parent``/``children`` are the tree links.
    ``visits``/``value`` are the accumulated search statistics, updated in
    place by the search code.
    """

    def __init__(self, state: str, action: Optional[str], parent: Optional['Node'] = None):
        self.state = state
        self.action = action
        self.parent = parent
        self.children: List['Node'] = []
        self.visits = 0
        self.value = 0.0

    def __repr__(self) -> str:
        # Compact summary for debugging; the state text is omitted because it
        # can be an entire model response.
        return (
            f"{type(self).__name__}(action={self.action!r}, visits={self.visits}, "
            f"value={self.value:.3f}, children={len(self.children)})"
        )
class RStar:
    """rStar-style solver: MCTS over LLM reasoning steps, followed by a
    mutual-consistency filter over the resulting trajectories.

    The root node holds the question.  Every other node holds the model's
    response to one of the prompting actions ``A1``-``A5`` (see
    ``create_prompt``) applied to its parent's state.  ``solve`` and
    ``solve_async`` return ``(answer, completion_tokens)``.
    """

    def __init__(self, system: str, client, model: str, max_depth: int = 3, num_rollouts: int = 5, c: float = 1.4):
        """Configure the search.

        Args:
            system: Caller-supplied system prompt.  NOTE(review): stored on
                ``self.system`` but not sent; ``generate_response`` uses its
                own fixed system message.
            client: Chat client; ``client.chat.completions.create(...)`` is
                called for every LLM step.
            model: Model name passed to the client.
            max_depth: Number of steps each simulation takes.
            num_rollouts: Number of MCTS rollouts per question.
            c: Exploration constant in the UCT formula.
        """
        self.client = client
        self.model_name = model
        self.max_depth = max_depth
        self.num_rollouts = num_rollouts
        self.c = c
        self.actions = ["A1", "A2", "A3", "A4", "A5"]
        self.original_question = None
        self.system = system
        # Accumulates over every LLM call made by this instance (never reset).
        self.rstar_completion_tokens = 0
        # generate_response runs in worker threads (asyncio.to_thread), so the
        # shared token counter is updated under this lock.
        self._token_lock = threading.Lock()
        logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")

    # ------------------------------------------------------------------
    # Async search
    # ------------------------------------------------------------------

    async def generate_response_async(self, prompt: str) -> str:
        """Run the blocking ``generate_response`` in a worker thread."""
        return await asyncio.to_thread(self.generate_response, prompt)

    async def expand_async(self, node: Node, action: str) -> Node:
        """Async counterpart of ``expand``: add and return a child created by ``action``."""
        prompt = self.create_prompt(node.state, action)
        new_state = await self.generate_response_async(prompt)
        child_node = Node(new_state, action, node)
        node.children.append(child_node)
        logger.debug(f"Expanded node with action: {action}")
        return child_node

    async def simulate_async(self, node: Node) -> float:
        """Async counterpart of ``simulate``."""
        current_node = node
        depth = 0
        logger.debug("Starting simulation")
        while depth < self.max_depth:
            if not current_node.children:
                action = random.choice(self.actions)
                current_node = await self.expand_async(current_node, action)
            else:
                current_node = random.choice(current_node.children)
            depth += 1
        value = self.evaluate(current_node)
        logger.debug(f"Simulation complete. Final value: {value}")
        return value

    async def mcts_async(self, root_state: str) -> List[List[Node]]:
        """Run ``num_rollouts`` rollouts concurrently from a fresh root.

        Returns every root-to-leaf trajectory of the resulting tree.
        """
        root = Node(root_state, None)
        logger.debug(f"Starting MCTS with {self.num_rollouts} rollouts")
        # Rollouts start together, so they all observe the tree before any of
        # them has expanded it and may still choose the same action for a node.
        await asyncio.gather(*(self.mcts_rollout_async(root) for _ in range(self.num_rollouts)))
        logger.debug("MCTS complete")
        return self.extract_trajectories(root)

    async def mcts_rollout_async(self, root: Node):
        """One rollout: select, expand an untried action, simulate, backpropagate."""
        node = self._select_expandable(root)
        untried = self._untried_actions(node)
        if untried:
            node = await self.expand_async(node, random.choice(untried))
        value = await self.simulate_async(node)
        self.backpropagate(node, value)

    async def solve_async(self, question: str) -> Tuple[str, int]:
        """Search for an answer to ``question``.

        Returns:
            ``(answer, completion_tokens)``.  On failure the first element is
            a human-readable message; the tuple shape is the same either way.
        """
        self.original_question = question
        logger.info(f"Solving question: {question}")
        trajectories = await self.mcts_async(question)
        if not trajectories:
            logger.warning("No trajectories found. Unable to solve the question.")
            return "Unable to solve the question due to insufficient reasoning paths.", self.rstar_completion_tokens
        # Mutual-consistency checks issue blocking LLM calls; keep them off the event loop.
        final_trajectory = await asyncio.to_thread(self.select_final_trajectory, trajectories)
        logger.debug(f"Final trajectory: {[node.state for node in final_trajectory]}")
        # The first node of every trajectory is the root, whose state is the
        # question itself; skip it so numbers quoted in the question are not
        # mistaken for answers.
        answers = [self.extract_answer(node.state) for node in final_trajectory[1:]]
        final_answer = self.select_best_answer(answers)
        logger.info(f"Selected final answer: {final_answer}")
        return final_answer, self.rstar_completion_tokens

    # ------------------------------------------------------------------
    # LLM access
    # ------------------------------------------------------------------

    def generate_response(self, prompt: str) -> str:
        """Send ``prompt`` to the model and return the stripped reply text.

        Adds the reported completion tokens to ``rstar_completion_tokens``.
        A missing ``usage`` block or ``None`` message content is tolerated.
        """
        logger.debug(f"Generating response for prompt: {prompt[:100]}...")
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant focused on solving mathematical problems. Stick to the given question and avoid introducing new scenarios."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=4096,
            temperature=0.2
        )
        usage = getattr(response, "usage", None)
        if usage is not None and usage.completion_tokens:
            with self._token_lock:
                self.rstar_completion_tokens += usage.completion_tokens
        content = response.choices[0].message.content or ""
        generated_response = content.strip()
        logger.debug(f"Generated response: {generated_response}")
        return generated_response

    # ------------------------------------------------------------------
    # Tree policy
    # ------------------------------------------------------------------

    def select_action(self, node: Node) -> Tuple[Node, str]:
        """Pick the child of ``node`` with the highest UCT score.

        Unvisited children score ``inf`` and are therefore tried first.  If
        ``node`` has no children, returns ``(node, <random action>)``.
        """
        if not node.children:
            action = random.choice(self.actions)
            logger.debug(f"Selected random action: {action}")
            return node, action
        uct_values = []
        for child in node.children:
            if child.visits == 0:
                uct = float('inf')
            else:
                # Parent visits >= child visits > 0 here, because
                # backpropagation always updates every ancestor.
                uct = child.value / child.visits + self.c * math.sqrt(math.log(node.visits) / child.visits)
            uct_values.append(uct)
        best_child = node.children[uct_values.index(max(uct_values))]
        logger.debug(f"Selected action: {best_child.action}")
        return best_child, best_child.action

    def _untried_actions(self, node: Node) -> List[str]:
        """Return the actions not yet used to create a child of ``node``."""
        tried = {child.action for child in node.children}
        return [action for action in self.actions if action not in tried]

    def _select_expandable(self, root: Node) -> Node:
        """Descend from ``root`` via UCT through fully-expanded nodes only.

        Stops at the first node that is a leaf or still has an untried action.
        """
        node = root
        while node.children and not self._untried_actions(node):
            node, _ = self.select_action(node)
        return node

    def expand(self, node: Node, action: str) -> Node:
        """Ask the model to apply ``action`` to ``node.state``; attach and return the new child."""
        prompt = self.create_prompt(node.state, action)
        new_state = self.generate_response(prompt)
        child_node = Node(new_state, action, node)
        node.children.append(child_node)
        logger.debug(f"Expanded node with action: {action}")
        return child_node

    def simulate(self, node: Node) -> float:
        """Random playout of ``max_depth`` steps from ``node``.

        A leaf is expanded with a random action; otherwise a random existing
        child is followed.  Returns ``evaluate`` of the node reached.  Nodes
        created here stay in the tree but are not backpropagated through.
        """
        current_node = node
        depth = 0
        logger.debug("Starting simulation")
        while depth < self.max_depth:
            if not current_node.children:
                action = random.choice(self.actions)
                current_node = self.expand(current_node, action)
            else:
                current_node = random.choice(current_node.children)
            depth += 1
        value = self.evaluate(current_node)
        logger.debug(f"Simulation complete. Final value: {value}")
        return value

    def backpropagate(self, node: Node, value: float):
        """Add one visit and ``value`` to ``node`` and every ancestor up to the root."""
        logger.debug("Starting backpropagation")
        while node:
            node.visits += 1
            node.value += value
            node = node.parent
        logger.debug("Backpropagation complete")

    def mcts(self, root_state: str) -> List[List[Node]]:
        """Synchronous MCTS; returns every root-to-leaf trajectory."""
        root = Node(root_state, None)
        logger.debug(f"Starting MCTS with {self.num_rollouts} rollouts")
        for i in range(self.num_rollouts):
            logger.debug(f"Rollout {i+1}/{self.num_rollouts}")
            node = self._select_expandable(root)
            untried = self._untried_actions(node)
            if untried:
                node = self.expand(node, random.choice(untried))
            value = self.simulate(node)
            self.backpropagate(node, value)
        logger.debug("MCTS complete")
        return self.extract_trajectories(root)

    def extract_trajectories(self, root: Node) -> List[List[Node]]:
        """Return every root-to-leaf path (each path starts with ``root``)."""
        logger.debug("Extracting trajectories")
        trajectories = []
        stack = [(root, [])]
        while stack:
            node, path = stack.pop()
            if not node.children:
                trajectories.append(path + [node])
            else:
                for child in node.children:
                    stack.append((child, path + [node]))
        logger.debug(f"Extracted {len(trajectories)} trajectories")
        return trajectories

    # ------------------------------------------------------------------
    # Trajectory selection
    # ------------------------------------------------------------------

    def mutual_consistency(self, trajectory: List[Node]) -> bool:
        """Check a trajectory against the model's own completion.

        The trajectory is split at a random point.  The model completes the
        reasoning from the prefix, and the completion is compared with the
        actual suffix.  Trajectories too short to split are reported as
        inconsistent.
        """
        if len(trajectory) < 2:
            logger.debug("Mutual consistency check: Failed (trajectory too short to split)")
            return False
        split_index = random.randint(1, len(trajectory) - 1)
        partial_trajectory = trajectory[:split_index]
        prompt = self.create_discriminator_prompt(partial_trajectory)
        completion = self.generate_response(prompt)
        is_consistent = self.compare_completions(completion, trajectory[split_index:])
        logger.debug(f"Mutual consistency check: {'Passed' if is_consistent else 'Failed'}")
        return is_consistent

    def select_final_trajectory(self, trajectories: List[List[Node]]) -> List[Node]:
        """Return the best-scoring consistent trajectory, or the best overall if none pass."""
        logger.debug("Selecting final trajectory")
        valid_trajectories = [t for t in trajectories if self.mutual_consistency(t)]
        logger.debug(f"Found {len(valid_trajectories)} valid trajectories")
        if not valid_trajectories:
            logger.warning("No valid trajectories found. Selecting based on value/visits.")
            return max(trajectories, key=self.trajectory_score)
        return max(valid_trajectories, key=self.trajectory_score)

    def trajectory_score(self, trajectory: List[Node]) -> float:
        """Score a trajectory by the mean value of its deepest visited node.

        Leaves are often created during simulation and never visited.  Scoring
        them directly would give every trajectory 0.0.  Walking back to the
        deepest node with statistics recovers the rollout value credited to
        that branch.
        """
        if not trajectory:
            return float('-inf')
        for node in reversed(trajectory):
            if node.visits > 0:
                return node.value / node.visits
        return trajectory[-1].value

    def select_best_answer(self, answers: List[Tuple[str, float]]) -> str:
        """Pick the answer with the highest confidence, breaking ties by frequency.

        Empty answers are ignored.  Returns a fixed message if none remain.
        """
        valid_answers = [(answer, conf) for answer, conf in answers if answer]
        if not valid_answers:
            return "Unable to determine a valid answer."
        # answer -> (occurrence count, best confidence seen)
        answer_counts: Dict[str, Tuple[int, float]] = {}
        for answer, conf in valid_answers:
            count, best_conf = answer_counts.get(answer, (0, conf))
            answer_counts[answer] = (count + 1, max(best_conf, conf))
        # Sort by confidence, then by frequency; sorted() is stable, so remaining
        # ties keep first-seen order.
        sorted_answers = sorted(answer_counts.items(), key=lambda x: (-x[1][1], -x[1][0]))
        best_answer, (count, conf) = sorted_answers[0]
        logger.debug(f"Selected best answer: {best_answer} (count: {count}, confidence: {conf})")
        return best_answer

    # ------------------------------------------------------------------
    # Prompts and answer parsing
    # ------------------------------------------------------------------

    def create_prompt(self, state: str, action: str) -> str:
        """Build the user prompt applying ``action`` (one of ``self.actions``) to ``state``.

        Raises:
            KeyError: if ``action`` is not one of ``A1``-``A5``.
        """
        # original_question is initialised to None, so test its value rather
        # than its existence; before solve() runs there is no question yet.
        question = getattr(self, 'original_question', None) or "the original question"
        prompts = {
            "A1": (
                f"Given the current state: {state}\n"
                f"Generate the next logical step in solving {question}.\n"
                "Your response should be a single, clear thought that moves towards the solution.\n"
                "If you can determine the final answer at this step, state it clearly."
            ),
            "A2": (
                f"Given the current state: {state}\n"
                f"Continue the reasoning process to solve {question}.\n"
                "Provide the remaining steps needed to reach the final answer.\n"
                "Each step should be clear and directly related to solving the problem."
            ),
            "A3": (
                f"Given the current state: {state}\n"
                f"Identify a key sub-question that needs to be answered to solve {question}.\n"
                "State this sub-question clearly, then provide its answer.\n"
                "Explain how this sub-question and its answer contribute to solving the main problem."
            ),
            "A4": (
                f"Given the current state: {state}\n"
                f"Re-examine the previous step in solving {question} using Chain-of-Thought reasoning.\n"
                "Break down your thinking process explicitly, showing each logical step.\n"
                "If you reach a conclusion, state it clearly."
            ),
            "A5": (
                f"Given the current state: {state}\n"
                f"Rephrase {question} by clearly listing all relevant conditions and unknowns.\n"
                "Ensure that your rephrasing captures all important details from the original question.\n"
                "This rephrasing should help clarify the problem and guide the solution process."
            ),
        }
        prompt = prompts[action] + "\n\nIf you determine the final answer, explicitly state 'The final answer is [your numeric answer]' at the end of your response."
        logger.debug(f"Created prompt for action {action}: {prompt}")
        return prompt

    def create_discriminator_prompt(self, partial_trajectory: List[Node]) -> str:
        """Prompt asking the model to finish the reasoning in ``partial_trajectory``."""
        states = [node.state for node in partial_trajectory]
        partial_reasoning = " ".join(states)
        return f"Given the partial reasoning:\n{partial_reasoning}\nComplete the reasoning to solve the problem:"

    def compare_completions(self, completion: str, remaining_trajectory: List[Node], threshold: float = 0.7) -> bool:
        """Return True if the word-set Jaccard similarity exceeds ``threshold``.

        Compares ``completion`` with the joined states of ``remaining_trajectory``.
        Two empty texts are treated as not matching.
        """
        remaining_reasoning = " ".join(node.state for node in remaining_trajectory)

        def _words(text: str) -> set:
            # Normalize: lowercase, drop '.' and ',', split on whitespace.
            return set(text.lower().replace('.', '').replace(',', '').split())

        completion_words = _words(completion)
        trajectory_words = _words(remaining_reasoning)
        union = completion_words | trajectory_words
        if not union:
            return False
        overlap = len(completion_words & trajectory_words)
        return overlap / len(union) > threshold

    def evaluate(self, node: Node) -> float:
        """Value of a node: the extraction confidence if a numeric answer is found, else 0.0."""
        answer, confidence = self.extract_answer(node.state)
        try:
            float(answer)
            logger.debug(f"Evaluated node. Answer: {answer}, Confidence: {confidence}, Value: {confidence}")
            return confidence
        except ValueError:
            logger.debug(f"Evaluated node. Answer: {answer}, Confidence: {confidence}, Value: 0.0")
            return 0.0

    def extract_answer(self, final_state: str) -> Tuple[str, float]:
        """Extract a numeric answer from model text.

        Returns:
            ``(answer, 1.0)`` for an explicit "final answer is N" or
            "the answer is N" statement.  The final-answer phrasing is
            preferred, and the last occurrence wins.  Otherwise returns
            ``(last number in text, 0.5)``, or ``("", 0.0)`` if the text
            contains no number.  Negative and decimal numbers are kept intact.
        """
        logger.debug(f"Extracting answer from state: {final_state}")
        number = r"(-?\d+(?:\.\d+)?)"
        # Most specific phrasing first: the prompts ask for "The final answer is ...".
        # Optional ':', '$' or '*' (markdown/LaTeX) may sit between the phrase and the number.
        patterns = [
            rf"final answer is[\s:$*]*{number}",
            rf"\bthe answer is[\s:$*]*{number}",
        ]
        for pattern in patterns:
            matches = re.findall(pattern, final_state, flags=re.IGNORECASE)
            if matches:
                # Conclusions come last; earlier matches may be sub-answers.
                answer = matches[-1]
                confidence = 1.0
                logger.debug(f"Answer found using pattern '{pattern}': {answer}")
                return answer, confidence
        # No explicit statement: fall back to the last number in the text.  A
        # '-' counts as a sign only when not directly after a word character,
        # so "step-1" yields "1" rather than "-1".
        numbers = re.findall(r"(?:(?<!\w)-)?\d+(?:\.\d+)?", final_state)
        if numbers:
            answer = numbers[-1]
            confidence = 0.5
            logger.debug(f"No pattern found. Using last number as answer: {answer}")
            return answer, confidence
        logger.warning("No answer found in the state.")
        return "", 0.0

    def solve(self, question: str) -> Tuple[str, int]:
        """
        Synchronous wrapper for solve_async method.

        Returns ``(answer, completion_tokens)``.  Uses ``asyncio.run``, so it
        must not be called from inside a running event loop.
        """
        return asyncio.run(self.solve_async(question))