-
Notifications
You must be signed in to change notification settings - Fork 4.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MCTS] Add self-refined MCTS (#6098)
* add reasoner * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update code * delete llama * update prompts * update readme * update readme --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
4294ae8
commit 89a9a60
Showing
5 changed files
with
324 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
26 changes: 26 additions & 0 deletions
26
applications/ColossalChat/coati/reasoner/guided_search/llm.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import openai | ||
from openai.types.chat.chat_completion import ChatCompletion | ||
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam | ||
|
||
API_KEY = "Dummy API Key" | ||
|
||
|
||
def get_client(base_url: str | None = None) -> openai.Client: | ||
return openai.Client(api_key=API_KEY, base_url=base_url) | ||
|
||
|
||
def chat_completion( | ||
messages: list[ChatCompletionMessageParam], | ||
model: str, | ||
base_url: str | None = None, | ||
temperature: float = 0.8, | ||
**kwargs, | ||
) -> ChatCompletion: | ||
client = get_client(base_url) | ||
response = client.chat.completions.create( | ||
model=model, | ||
messages=messages, | ||
temperature=temperature, | ||
**kwargs, | ||
) | ||
return response |
250 changes: 250 additions & 0 deletions
250
applications/ColossalChat/coati/reasoner/guided_search/mcts.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
""" | ||
Implementation of MCTS + Self-refine algorithm. | ||
Reference: | ||
1. "Accessing GPT-4 level Mathematical Olympiad Solutions via Monte | ||
Carlo Tree Self-refine with LLaMa-3 8B: A Technical Report" | ||
2. https://github.com/BrendanGraham14/mcts-llm/ | ||
3. https://github.com/trotsky1997/MathBlackBox/ | ||
4. https://github.com/openreasoner/openr/blob/main/reason/guided_search/tree.py | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import math | ||
from collections import deque | ||
|
||
import numpy as np | ||
import tqdm | ||
from coati.reasoner.guided_search.llm import chat_completion | ||
from coati.reasoner.guided_search.prompt_store.base import PromptCFG | ||
from pydantic import BaseModel | ||
|
||
|
||
class MCTSNode(BaseModel): | ||
""" | ||
Node for MCTS. | ||
""" | ||
|
||
answer: str | ||
parent: MCTSNode = None | ||
children: list[MCTSNode] = [] | ||
num_visits: int = 0 | ||
Q: int = 0 | ||
rewards: list[int] = [] | ||
|
||
def expand_node(self, node) -> None: | ||
self.children.append(node) | ||
|
||
def add_reward(self, reward: int) -> None: | ||
self.rewards.append(reward) | ||
self.Q = (np.min(self.rewards) + np.mean(self.rewards)) / 2 | ||
|
||
|
||
class MCTS(BaseModel): | ||
""" | ||
Simulation of MCTS process. | ||
""" | ||
|
||
problem: str | ||
max_simulations: int | ||
cfg: PromptCFG | ||
C: float = 1.4 | ||
max_children: int = 2 | ||
epsilon: float = 1e-5 | ||
root: MCTSNode = None | ||
|
||
def initialization(self): | ||
""" | ||
Root Initiation. | ||
""" | ||
# Dummy answer as root. | ||
base_answer = self.sample_base_answer() | ||
self.root = MCTSNode(answer=base_answer) | ||
self.self_evaluate(self.root) | ||
|
||
def is_fully_expanded(self, node: MCTSNode): | ||
return len(node.children) >= self.max_children or any(child.Q > node.Q for child in node.children) | ||
|
||
def select_node(self) -> MCTSNode: | ||
""" | ||
Select next node to explore. | ||
""" | ||
candidates: list[MCTSNode] = [] | ||
to_explore = deque([self.root]) | ||
|
||
while to_explore: | ||
current_node = to_explore.popleft() | ||
if not self.is_fully_expanded(current_node): | ||
candidates.append(current_node) | ||
to_explore.extend(current_node.children) | ||
|
||
if not candidates: | ||
return self.root | ||
|
||
return max(candidates, key=self.compute_uct) | ||
|
||
def self_evaluate(self, node: MCTSNode): | ||
""" | ||
Sample reward of the answer. | ||
""" | ||
reward = self.sample_reward(node) | ||
node.add_reward(reward) | ||
|
||
def back_propagation(self, node: MCTSNode): | ||
""" | ||
Back propagate the value of the refined answer. | ||
""" | ||
parent = node.parent | ||
while parent: | ||
best_child_Q = max(child.Q for child in parent.children) | ||
parent.Q = (parent.Q + best_child_Q) / 2 | ||
parent.num_visits += 1 | ||
parent = parent.parent | ||
|
||
def compute_uct(self, node: MCTSNode): | ||
""" | ||
Compute UCT. | ||
""" | ||
if node.parent is None: | ||
return -100 | ||
return node.Q + self.C * math.sqrt(math.log(node.parent.num_visits + 1) / (node.num_visits + self.epsilon)) | ||
|
||
def simulate(self): | ||
self.initialization() | ||
for _ in tqdm.tqdm(range(self.max_simulations)): | ||
node = self.select_node() | ||
child = self.self_refine(node) | ||
node.expand_node(child) | ||
self.self_evaluate(child) | ||
self.back_propagation(child) | ||
|
||
return self.get_best_answer() | ||
|
||
def get_best_answer(self): | ||
to_visit = deque([self.root]) | ||
best_node = self.root | ||
|
||
while to_visit: | ||
current_node = to_visit.popleft() | ||
if current_node.Q > best_node.Q: | ||
best_node = current_node | ||
to_visit.extend(current_node.children) | ||
|
||
return best_node.answer | ||
|
||
def self_refine(self, node: MCTSNode): | ||
""" | ||
Refine node. | ||
""" | ||
critique_response = chat_completion( | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": self.cfg.critic_system_prompt, | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "\n\n".join( | ||
[ | ||
f"<problem>\n{self.problem}\n</problem>", | ||
f"<current_answer>\n{node.answer}\n</current_answer>", | ||
] | ||
), | ||
}, | ||
], | ||
model=self.cfg.model, | ||
base_url=self.cfg.base_url, | ||
max_tokens=self.cfg.max_tokens, | ||
) | ||
critique = critique_response.choices[0].message.content | ||
assert critique is not None | ||
refined_answer_response = chat_completion( | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": self.cfg.refine_system_prompt, | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "\n\n".join( | ||
[ | ||
f"<problem>\n{self.problem}\n</problem>", | ||
f"<current_answer>\n{node.answer}\n</current_answer>", | ||
f"<critique>\n{critique}\n</critique>", | ||
] | ||
), | ||
}, | ||
], | ||
model=self.cfg.model, | ||
base_url=self.cfg.base_url, | ||
max_tokens=self.cfg.max_tokens, | ||
) | ||
refined_answer = refined_answer_response.choices[0].message.content | ||
assert refined_answer is not None | ||
|
||
return MCTSNode(answer=refined_answer, parent=node) | ||
|
||
def sample_base_answer(self): | ||
response = chat_completion( | ||
messages=[ | ||
{ | ||
"role": "system", | ||
"content": "The user will provide a problem. Solve the problem. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. \nThe answer is [answer] \n#### [answer].", | ||
}, | ||
{ | ||
"role": "user", | ||
"content": f"<problem>\n {self.problem} \n</problem> \nLet's think step by step", | ||
}, | ||
], | ||
model=self.cfg.model, | ||
base_url=self.cfg.base_url, | ||
max_tokens=self.cfg.max_tokens, | ||
) | ||
assert response.choices[0].message.content is not None | ||
return response.choices[0].message.content | ||
|
||
def sample_reward(self, node: MCTSNode): | ||
""" | ||
Calculate reward. | ||
""" | ||
messages = [ | ||
{ | ||
"role": "system", | ||
"content": self.cfg.evaluate_system_prompt, | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "\n\n".join( | ||
[ | ||
f"<problem>\n{self.problem}\n</problem>", | ||
f"<answer>\n{node.answer}\n</answer>", | ||
] | ||
), | ||
}, | ||
] | ||
for attempt in range(3): | ||
try: | ||
response = chat_completion( | ||
messages=messages, | ||
model=self.cfg.model, | ||
base_url=self.cfg.base_url, | ||
max_tokens=self.cfg.max_tokens, | ||
) | ||
assert response.choices[0].message.content is not None | ||
return int(response.choices[0].message.content) | ||
except ValueError: | ||
messages.extend( | ||
[ | ||
{ | ||
"role": "assistant", | ||
"content": response.choices[0].message.content, | ||
}, | ||
{ | ||
"role": "user", | ||
"content": "Failed to parse reward as an integer.", | ||
}, | ||
] | ||
) | ||
if attempt == 2: | ||
raise |
10 changes: 10 additions & 0 deletions
10
applications/ColossalChat/coati/reasoner/guided_search/prompt_store/base.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class PromptCFG(BaseModel): | ||
model: str | ||
base_url: str | ||
max_tokens: int = 4096 | ||
critic_system_prompt: str | ||
refine_system_prompt: str | ||
evaluate_system_prompt: str |
20 changes: 20 additions & 0 deletions
20
applications/ColossalChat/coati/reasoner/guided_search/prompt_store/qwen.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
""" | ||
Prompts for Qwen Series. | ||
""" | ||
|
||
from coati.reasoner.guided_search.prompt_store.base import PromptCFG | ||
|
||
Qwen32B_prompt_CFG = PromptCFG( | ||
base_url="http://0.0.0.0:8008/v1", | ||
model="Qwen2.5-32B-Instruct", | ||
critic_system_prompt="Provide a detailed and constructive critique to improve the answer. " | ||
"Highlight specific areas that need refinement or correction.", | ||
refine_system_prompt="""# Instruction | ||
Refine the answer based on the critique. The response should begin with [reasoning process]...[Verification]... and end with [Final Answer]. | ||
""", | ||
evaluate_system_prompt=( | ||
"Analyze this answer strictly and critic, provide a reward score between -100 and 100 for the answer quality, using very strict standards. " | ||
"Do not give a full score above 95. Make sure the reward score is an integer. " | ||
"Return *ONLY* the score." | ||
), | ||
) |