SN1-336: Create multi step reasoning task (#466)
Co-authored-by: bkb2135 <98138173+bkb2135@users.noreply.github.com>
Co-authored-by: richwardle <richard.wardle@macrocosmos.ai>
Co-authored-by: Dmytro Bobrenko <17252809+dbobrenko@users.noreply.github.com>
4 people authored Dec 2, 2024
1 parent 2f50da2 commit 11df291
Showing 5 changed files with 171 additions and 20 deletions.
3 changes: 1 addition & 2 deletions prompting/api/gpt_endpoints/api.py
@@ -82,7 +82,7 @@ async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(
):
raise HTTPException(
status_code=503,
detail=f"No miners available for model: {body.get('model')} and task: {task.__name__}",
detail=f"No miners available for model: {body.get('model')} and task: {task.__class__.__name__}",
)

response = query_miners(available_miners, json.dumps(body).encode("utf-8"), stream=stream)
@@ -113,5 +113,4 @@ async def proxy_chat_completions(request: Request, api_key_data: dict = Depends(
step=-1,
task_id=task.task_id,
)

return [res.model_dump() for res in response]
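The one-line fix above suggests `task` is an instance here rather than a class: instances do not expose `__name__`, so the original f-string would raise an `AttributeError` while building the 503 detail. A minimal sketch of the distinction, using a hypothetical task class:

```python
# Minimal sketch (hypothetical class) of why the detail string needs task.__class__.__name__.
class QuestionAnsweringTask:
    pass

task = QuestionAnsweringTask()

print(task.__class__.__name__)  # prints "QuestionAnsweringTask"
# print(task.__name__)          # AttributeError: the instance has no attribute '__name__'
```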
2 changes: 1 addition & 1 deletion prompting/llms/apis/sn19_wrapper.py
@@ -36,7 +36,7 @@ def chat_complete(
"stream": stream,
"logprobs": logprobs,
}
-        response = requests.post(url, headers=headers, data=json.dumps(data))
+        response = requests.post(url, headers=headers, data=json.dumps(data), timeout=30)
try:
response_json = response.json()
try:
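The added `timeout=30` bounds how long the SN19 call can block; without it, `requests.post` can hang indefinitely on an unresponsive endpoint. A minimal sketch of the failure mode it guards against (the URL and payload below are placeholders, not the wrapper's real values):

```python
import requests

# Placeholder endpoint and payload; the real values come from the SN19 wrapper.
url = "https://example.invalid/v1/chat/completions"
data = {"messages": [{"role": "user", "content": "ping"}]}

try:
    # Without a timeout this call may block forever; 30 s matches the change above.
    response = requests.post(url, json=data, timeout=30)
    response.raise_for_status()
    print(response.json())
except requests.exceptions.Timeout:
    print("SN19 request timed out after 30 seconds")
except requests.exceptions.RequestException as exc:
    print(f"SN19 request failed: {exc}")
```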
160 changes: 160 additions & 0 deletions prompting/tasks/multi_step_reasoning.py
@@ -0,0 +1,160 @@
import json
import re
import time
from typing import ClassVar

from loguru import logger

from prompting.datasets.base import Context
from prompting.llms.apis.gpt_wrapper import LLMMessage, LLMMessages
from prompting.llms.apis.llm_wrapper import LLMWrapper
from prompting.rewards.relevance import RelevanceRewardModel
from prompting.rewards.reward import BaseRewardConfig, BaseRewardModel
from prompting.tasks.qa import QuestionAnsweringTask
from prompting.utils.cleaners import CleanerPipeline, PruneEnding, RemovePostQuestionText, RemoveQuotes, RemoveRoles
from prompting.utils.timer import Timer

MAX_THINKING_STEPS = 10


def make_api_call(messages, max_tokens, is_final_answer=False):
    # TODO: Make this use a local model to prevent relay mining
for attempt in range(3):
try:
response = LLMWrapper.chat_complete(messages=LLMMessages(*messages))
return json.loads(re.sub("```", "", re.sub(r"```json\s*", "", response)))
except Exception as e:
if attempt == 2:
if is_final_answer:
return {
"title": "Error",
"content": f"Failed to generate final answer after 3 attempts. Error: {str(e)}",
}
else:
return {
"title": "Error",
"content": f"Failed to generate step after 3 attempts. Error: {str(e)}",
"next_action": "final_answer",
}
time.sleep(1) # Wait for 1 second before retrying


def generate_response(prompt):
messages = [
LLMMessage(
role="system",
content="""You are an expert AI assistant with advanced reasoning capabilities. Your task is to provide detailed, step-by-step explanations of your thought process. For each step:
1. Provide a clear, concise title describing the current reasoning phase.
2. Elaborate on your thought process in the content section.
3. Decide whether to continue reasoning or provide a final answer.
Response Format:
Use JSON with keys: 'title', 'content', 'next_action' (values: 'continue' or 'final_answer')
Key Instructions:
- Employ at least 5 distinct reasoning steps.
- Acknowledge your limitations as an AI and explicitly state what you can and cannot do.
- Actively explore and evaluate alternative answers or approaches.
- Critically assess your own reasoning; identify potential flaws or biases.
- When re-examining, employ a fundamentally different approach or perspective.
- Utilize at least 3 diverse methods to derive or verify your answer.
- Incorporate relevant domain knowledge and best practices in your reasoning.
- Quantify certainty levels for each step and the final conclusion when applicable.
- Consider potential edge cases or exceptions to your reasoning.
- Provide clear justifications for eliminating alternative hypotheses.
Example of a valid JSON response:
```json
{
"title": "Initial Problem Analysis",
"content": "To approach this problem effectively, I'll first break down the given information into key components. This involves identifying...[detailed explanation]... By structuring the problem this way, we can systematically address each aspect.",
"next_action": "continue"
}```
""",
)
]
messages += [LLMMessage(role="user", content=prompt)]
messages += [
LLMMessage(
role="assistant",
content="Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem.",
)
]

steps = []
step_count = 1
total_thinking_time = 0

for _ in range(MAX_THINKING_STEPS):
with Timer() as timer:
step_data = make_api_call(messages, 300)
thinking_time = timer.final_time
total_thinking_time += thinking_time

steps.append((f"Step {step_count}: {step_data['title']}", step_data["content"], thinking_time))

messages.append(LLMMessage(role="assistant", content=json.dumps(step_data)))

if step_data["next_action"] == "final_answer" or not step_data.get("next_action"):
break

step_count += 1

# Yield after each step
yield steps, None

# Generate final answer
messages.append(LLMMessage(role="user", content="Please provide the final answer based on your reasoning above."))

start_time = time.time()
final_data = make_api_call(messages, 200, is_final_answer=True)
end_time = time.time()
thinking_time = end_time - start_time
total_thinking_time += thinking_time

steps.append(("Final Answer", final_data["content"], thinking_time))

yield steps, total_thinking_time


def execute_multi_step_reasoning(user_query):
for steps, total_thinking_time in generate_response(user_query):
if total_thinking_time is not None:
logger.info(f"**Total thinking time: {total_thinking_time:.2f} seconds**")
return steps, total_thinking_time


class MultiStepReasoningRewardConfig(BaseRewardConfig):
reward_definitions: ClassVar[list[BaseRewardModel]] = [
RelevanceRewardModel(weight=1),
]


class MultiStepReasoningTask(QuestionAnsweringTask):
"""QuestionAnsweringTasks must be initialised with an LLM pipeline to generate query and reference plus
context from a dataset to base the query on"""

cleaning_pipeline: ClassVar[CleanerPipeline] = CleanerPipeline(
cleaning_pipeline=[
RemoveQuotes(),
PruneEnding(),
RemoveRoles(),
RemovePostQuestionText(),
]
)
name: ClassVar[str] = "multi_step_reasoning"
augmentation_system_prompt: ClassVar[str] = ""
query: str | None = None
reference: str | None = None

def make_reference(self, dataset_entry: Context):
logger.info(f"Generating reference for Multi Step Reasoning task with query: {self.query}")
steps, total_thinking_time = execute_multi_step_reasoning(user_query=self.query)
logger.info(
f"**Steps: {steps}**, **Total thinking time for multi step reasoning: {total_thinking_time} seconds**"
)
logger.info(f"**Total thinking time for multi step reasoning: {total_thinking_time} seconds**")
self.reference = steps[-1][1]
return self.reference
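A rough usage sketch of the new module, assuming an LLM backend is configured behind `LLMWrapper` (the query string is illustrative): `execute_multi_step_reasoning` drains the `generate_response` generator, and `make_reference` keeps only the content of the final step as the task reference.

```python
# Illustrative only; requires a configured LLM backend behind LLMWrapper.
from prompting.tasks.multi_step_reasoning import execute_multi_step_reasoning

steps, total_thinking_time = execute_multi_step_reasoning(
    user_query="Is 1729 expressible as a sum of two cubes in two different ways?"
)

# Each step is a (title, content, thinking_time) tuple.
for title, content, thinking_time in steps:
    print(f"{title} ({thinking_time:.2f}s)\n{content}\n")

# make_reference() stores the last step's content, i.e. steps[-1][1], as the reference answer.
print("Reference:", steps[-1][1])
```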
15 changes: 0 additions & 15 deletions prompting/tasks/qa.py
@@ -39,21 +39,6 @@
{question}
"""

-FOLLOWUP_REFERENCE_PROMPT_TEMPLATE = """\
-You are a helpful assistant. Answer the question below in detail, prioritizing the use of the provided conversation history. The context is available for additional information if needed, but it may not always be relevant.
-# Conversation History:
-{history}
-# Context (optional):
-{context}
-# Question:
-{question}
-Ensure your answer references relevant parts of the conversation history. Use the context only if it provides additional necessary information.
-"""


class QARewardConfig(BaseRewardConfig):
reward_definitions: ClassVar[list[BaseRewardModel]] = [
11 changes: 9 additions & 2 deletions prompting/tasks/task_registry.py
@@ -15,6 +15,7 @@
from prompting.tasks.date_qa import DateQARewardConfig, DateQuestionAnsweringTask
from prompting.tasks.inference import InferenceRewardConfig, InferenceTask
from prompting.tasks.multi_choice import MultiChoiceRewardConfig, MultiChoiceTask
+from prompting.tasks.multi_step_reasoning import MultiStepReasoningRewardConfig, MultiStepReasoningTask
from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask
from prompting.tasks.qa import QARewardConfig, QuestionAnsweringTask
from prompting.tasks.summarization import SummarizationRewardConfig, SummarizationTask
@@ -37,7 +38,7 @@ def __hash__(self):

class TaskRegistry(BaseModel):
task_configs: ClassVar[list[TaskConfig]] = [
-        TaskConfig(task=QuestionAnsweringTask, probability=0.2, datasets=[WikiDataset], reward_model=QARewardConfig),
+        TaskConfig(task=QuestionAnsweringTask, probability=0.15, datasets=[WikiDataset], reward_model=QARewardConfig),
TaskConfig(
task=SummarizationTask, probability=0.1, datasets=[WikiDataset], reward_model=SummarizationRewardConfig
),
@@ -55,7 +56,7 @@ class TaskRegistry(BaseModel):
),
TaskConfig(
task=MultiChoiceTask,
-            probability=0.31,
+            probability=0.26,
datasets=[WikiDataset],
reward_model=MultiChoiceRewardConfig,
),
@@ -71,6 +72,12 @@ class TaskRegistry(BaseModel):
datasets=[DDGDataset],
reward_model=WebRetrievalRewardConfig,
),
+        TaskConfig(
+            task=MultiStepReasoningTask,
+            probability=0.1,
+            datasets=[WikiDataset],
+            reward_model=MultiStepReasoningRewardConfig,
+        ),
]

@classmethod
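The rebalancing frees 0.05 each from `QuestionAnsweringTask` (0.2 to 0.15) and `MultiChoiceTask` (0.31 to 0.26), exactly covering the new `MultiStepReasoningTask` at 0.1, so the overall task distribution is preserved. A quick sanity-check sketch, assuming the registry expects the probabilities to form a distribution summing to 1:

```python
# Sketch: verify the task probabilities still form a complete distribution.
from prompting.tasks.task_registry import TaskRegistry

total = sum(config.probability for config in TaskRegistry.task_configs)
assert abs(total - 1.0) < 1e-6, f"Task probabilities sum to {total}, expected 1.0"
```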
