diff --git a/tree-of-thoughts.ipynb b/tree-of-thoughts.ipynb new file mode 100644 index 0000000000..6856286921 --- /dev/null +++ b/tree-of-thoughts.ipynb @@ -0,0 +1,741 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tree of Thoughts for problem solving with large language models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "TL;DR: This blog post shows how to use \"Tree of Thoughts\", a tree-based reasoning framework, to solve the Game of 24 task with a large language model." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Tree of Thoughts (ToT) is a framework that enables LLMs to solve complex reasoning problems. As in Chain of Thought, the intermediate steps of a reasoning process are split into \"thoughts\"; unlike Chain of Thought, multiple thoughts are generated per step, resulting in a tree-like structure. A search algorithm then allows ToT to explore and evaluate the thoughts." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. Load Model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this tutorial, we'll use two different LLMs: Mistral and GPT-4. \n", + "\n", + "We can use the Hugging Face ```transformers``` library to generate text with our LLMs. First, we start off by importing the necessary libraries." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3a9253adff2a409084d62fe896a7b04e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(HTML(value='
list:\n", + " \n", + " messages = [{\"role\": \"user\", \"content\": prompt}]\n", + " outputs = []\n", + "\n", + " response = client.chat.completions.create(model=model, messages=messages, temperature=temperature, max_tokens=max_tokens, n=n, stop=stop)\n", + " \n", + " for choice in response.choices:\n", + " outputs.append(choice.message.content)\n", + "\n", + " return outputs " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Implementing Tree of Thoughts (ToT) " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "ToT can be broken down into 4 key steps as described below. \n", + "\n", + "(Step 0): Thought Decomposition\n", + "- The problem is decomposed into intermediate thought steps; for the Game of 24, each thought is one arithmetic operation on the remaining numbers.\n", + "\n", + "(Step 1): Thought Generation\n", + "- In this step, the LLM is prompted to generate thoughts in one of two ways:\n", + " - Sample: The thoughts are generated by sampling i.i.d. thoughts from a Chain of Thought prompt.\n", + " - Propose: The thoughts are proposed sequentially, each conditioned on the previous thoughts. \n", + "\n", + "(Step 2): Thought Evaluation\n", + "- The LLM is prompted to evaluate the thoughts generated in the previous step, by either: \n", + " - Value: Each thought is assigned a score individually. \n", + " - Vote: All thoughts are evaluated together and the most promising one is voted for.\n", + "\n", + "(Step 3): Search Algorithm\n", + "- A search algorithm is used to explore the thoughts generated in the previous steps:\n", + " - Breadth-first search\n", + " - Depth-first search\n", + "\n", + "\n", + "In this tutorial, we'll be using ToT with Mistral-7B-v0.3 to solve the Game of 24.\n", + "\n", + "The Game of 24 is a task where, given a sequence of 4 numbers, we’ll need to find the correct mathematical operations (add, subtract, multiply, divide) that lead to the number 24. For example, if the sequence is {4, 9, 10, 13}, the correct operations using the 4 numbers are: (10 - 4) * (13 - 9) = 24. 
Each number in the sequence can only be used once.\n", + "\n", + "In this tutorial, we'll be using 'Propose' for Thought Generation, 'Value' for Thought Evaluation and 'Breadth-first search' for the search algorithm." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.0 Thought Decomposition" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We define the prompts necessary for the ToT framework. Each prompt will be used in a different stage of the ToT process. The ```generate_prompt``` is meant to guide the LLM to come up with possible next steps, given a certain point in the problem.\n", + "\n", + "The ```value_prompt``` assigns a classification (sure/likely/impossible) to each thought, depending on how likely it is to reach the number 24 given the current sequence of thoughts." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Prompts\n", + "\n", + "# 5-shot\n", + "standard_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24.\n", + "Input: 4 4 6 8\n", + "Answer: (4 + 8) * (6 - 4) = 24\n", + "Input: 2 9 10 12\n", + "Answer: 2 * 12 * (10 - 9) = 24\n", + "Input: 4 9 10 13\n", + "Answer: (13 - 9) * (10 - 4) = 24\n", + "Input: 1 4 8 8\n", + "Answer: (8 / 4 + 1) * 8 = 24\n", + "Input: 5 5 5 9\n", + "Answer: 5 + 5 + 5 + 9 = 24\n", + "Input: {input}\n", + "'''\n", + "\n", + "\n", + "# PROMPTS FOR THOUGHT GENERATION\n", + "cot_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. 
Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.\n", + "Input: 4 4 6 8\n", + "Steps:\n", + "4 + 8 = 12 (left: 4 6 12)\n", + "6 - 4 = 2 (left: 2 12)\n", + "2 * 12 = 24 (left: 24)\n", + "Answer: (6 - 4) * (4 + 8) = 24\n", + "Input: 2 9 10 12\n", + "Steps:\n", + "12 * 2 = 24 (left: 9 10 24)\n", + "10 - 9 = 1 (left: 1 24)\n", + "24 * 1 = 24 (left: 24)\n", + "Answer: (12 * 2) * (10 - 9) = 24\n", + "Input: 4 9 10 13\n", + "Steps:\n", + "13 - 10 = 3 (left: 3 4 9)\n", + "9 - 3 = 6 (left: 4 6)\n", + "4 * 6 = 24 (left: 24)\n", + "Answer: 4 * (9 - (13 - 10)) = 24\n", + "Input: 1 4 8 8\n", + "Steps:\n", + "8 / 4 = 2 (left: 1 2 8)\n", + "1 + 2 = 3 (left: 3 8)\n", + "3 * 8 = 24 (left: 24)\n", + "Answer: (1 + 8 / 4) * 8 = 24\n", + "Input: 5 5 5 9\n", + "Steps:\n", + "5 + 5 = 10 (left: 5 9 10)\n", + "10 + 5 = 15 (left: 9 15)\n", + "15 + 9 = 24 (left: 24)\n", + "Answer: ((5 + 5) + 5) + 9 = 24\n", + "Input: {input}\n", + "'''\n", + "\n", + "generate_prompt = '''Input: 2 8 8 14\n", + "Possible next steps:\n", + "2 + 8 = 10 (left: 8 10 14)\n", + "8 / 2 = 4 (left: 4 8 14)\n", + "14 + 2 = 16 (left: 8 8 16)\n", + "2 * 8 = 16 (left: 8 14 16)\n", + "8 - 2 = 6 (left: 6 8 14)\n", + "14 - 8 = 6 (left: 2 6 8)\n", + "14 / 2 = 7 (left: 7 8 8)\n", + "14 - 2 = 12 (left: 8 8 12)\n", + "Input: {input}\n", + "Possible next steps:\n", + "'''\n", + "\n", + "# PROMPTS FOR THOUGHT EVALUATION\n", + "\n", + "value_prompt = '''Evaluate if given numbers can reach 24 (sure/likely/impossible)\n", + "10 14\n", + "10 + 14 = 24\n", + "sure\n", + "11 12\n", + "11 + 12 = 23\n", + "12 - 11 = 1\n", + "11 * 12 = 132\n", + "11 / 12 = 0.91\n", + "impossible\n", + "4 4 10\n", + "4 + 4 + 10 = 8 + 10 = 18\n", + "4 * 10 - 4 = 40 - 4 = 36\n", + "(10 - 4) * 4 = 6 * 4 = 24\n", + "sure\n", + "4 9 11\n", + "9 + 11 + 4 = 20 + 4 = 24\n", + "sure\n", + "5 7 8\n", + "5 + 7 + 8 = 12 + 8 = 20\n", + "(8 - 5) * 7 = 3 * 7 = 21\n", + "I cannot obtain 24 now, but numbers are within a reasonable 
range\n", + "likely\n", + "5 6 6\n", + "5 + 6 + 6 = 17\n", + "(6 - 5) * 6 = 1 * 6 = 6\n", + "I cannot obtain 24 now, but numbers are within a reasonable range\n", + "likely\n", + "10 10 11\n", + "10 + 10 + 11 = 31\n", + "(11 - 10) * 10 = 10\n", + "10 10 11 are all too big\n", + "impossible\n", + "1 3 3\n", + "1 * 3 * 3 = 9\n", + "(1 + 3) * 3 = 12\n", + "1 3 3 are all too small\n", + "impossible\n", + "{input}\n", + "'''\n", + "\n", + "value_last_step_prompt = '''Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reaches 24.\n", + "Input: 4 4 6 8\n", + "Answer: (4 + 8) * (6 - 4) = 24\n", + "Judge: \n", + "sure\n", + "Input: 2 9 10 12\n", + "Answer: 2 * 12 * (10 - 9) = 24\n", + "Judge: \n", + "sure\n", + "Input: 4 9 10 13\n", + "Answer: (13 - 9) * (10 - 4) = 24\n", + "Judge: \n", + "sure\n", + "Input: 4 4 6 8\n", + "Answer: (4 + 8) * (6 - 4) + 1 = 25\n", + "Judge: \n", + "impossible\n", + "Input: 2 9 10 12\n", + "Answer: 2 * (12 - 10) = 24\n", + "Judge: \n", + "impossible\n", + "Input: 4 9 10 13\n", + "Answer: (13 - 4) * (10 - 9) = 24\n", + "Judge: \n", + "impossible\n", + "Input: {input}\n", + "Answer: {answer}\n", + "Judge:'''" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll start implementing our ToT algorithm. 
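Before wiring these prompts into the ToT loop, it helps to pin down what counts as a valid Game of 24 answer. The following standalone sketch is not part of the notebook's ToT code (`check_answer` is a hypothetical helper); it verifies an answer expression the same way the ```value_last_step_prompt``` asks the model to, i.e. each input number is used exactly once and the expression evaluates to 24:

```python
import re

def check_answer(numbers: str, expression: str) -> bool:
    """Return True if `expression` uses each input number exactly once
    and evaluates to 24."""
    # The expression must use exactly the input numbers (as a multiset).
    if sorted(re.findall(r'\d+', expression)) != sorted(numbers.split()):
        return False
    try:
        # eval is acceptable here: the expression only contains digits and + - * / ( )
        return abs(eval(expression) - 24) < 1e-6
    except (SyntaxError, ZeroDivisionError):
        return False

print(check_answer('4 9 10 13', '(10 - 4) * (13 - 9)'))  # True
print(check_answer('4 9 10 13', '10 + 4 + 13 - 9'))      # False (equals 18)
```

The few-shot answers in the prompts above, e.g. `(4 + 8) * (6 - 4)` for input `4 4 6 8`, all pass this check.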
We'll define a function for each core part of the ToT algorithm.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.1 Thought Generation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we'll define the functions necessary for \"Thought Generation\".\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def get_current_numbers(y: str) -> str:\n", + " # Extract the remaining numbers from the '(left: ...)' tag on the last line\n", + " last_line = y.strip().split('\\n')[-1]\n", + " return last_line.split('left: ')[-1].split(')')[0]\n", + "\n", + "def prepare_generate_prompt(current_numbers, thought, data):\n", + " \n", + " if current_numbers == '24':\n", + " # 24 has been reached: ask the model to write out the final answer\n", + " prompt = cot_prompt.format(input=data) + 'Steps: ' + thought\n", + " else:\n", + " # Otherwise, ask the model to propose possible next steps\n", + " prompt = generate_prompt.format(input=current_numbers)\n", + "\n", + " return prompt\n", + "\n", + "\n", + "def generate_thoughts(data, thoughts):\n", + "\n", + " new_thoughts = []\n", + " \n", + " for thought in thoughts:\n", + "\n", + " # Prepare prompt\n", + " current_numbers = get_current_numbers(thought if thought else data)\n", + " prompt = prepare_generate_prompt(current_numbers, thought, data)\n", + " \n", + " # Generate proposals with the prompt, one per output line\n", + " proposals = mistral(prompt).split('\\n')\n", + " #proposals = gpt(prompt, n=1, stop=None)[0].split('\\n')\n", + " new_thoughts.extend(thought + proposal + '\\n' for proposal in proposals)\n", + "\n", + " return new_thoughts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.2 Thought Evaluation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll create the functions necessary for \"Thought Evaluation\", where each thought is evaluated by the LLM." 
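As a quick sanity check of the parsing used by the generation helpers above, ```get_current_numbers``` (restated here so the snippet runs standalone, with a made-up thought string) extracts the numbers still in play from the `(left: ...)` annotation:

```python
def get_current_numbers(y: str) -> str:
    # The remaining numbers are listed after 'left: ' on the last line of the thought
    last_line = y.strip().split('\n')[-1]
    return last_line.split('left: ')[-1].split(')')[0]

# A hypothetical two-step thought: only the last line's 'left: ' tag matters
thought = '10 + 4 = 14 (left: 4 6 14)\n14 - 4 = 10 (left: 6 10)\n'
print(get_current_numbers(thought))  # prints: 6 10

# With no 'left: ' tag (e.g. the raw input data), the string passes through unchanged
print(get_current_numbers('4 5 6 10'))  # prints: 4 5 6 10
```

The pass-through behaviour of the second call is what makes `get_current_numbers(thought if thought else data)` work on the very first step, when the thought is still empty.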
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, we create the ```prepare_evaluate_prompt``` function which turns our current thought into an evaluation prompt by using the ```value_prompt``` from above." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def prepare_evaluate_prompt(data: str, thought: str) -> str:\n", + " last_line = thought.strip().split('\\n')[-1]\n", + " if 'left: ' not in last_line:\n", + " # The thought ends with a final answer: judge it with the last-step prompt\n", + " ans = last_line.lower().replace('answer: ', '')\n", + " return value_last_step_prompt.format(input=data, answer=ans) \n", + " # Otherwise, value the numbers that are still left\n", + " current_numbers = get_current_numbers(thought)\n", + " return value_prompt.format(input=current_numbers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We then create the ```evaluate_outputs_unwrap``` function, which aggregates the labels assigned to a thought across evaluation samples into a single numeric score." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_outputs_unwrap(thought: str, evaluate_outputs: list) -> float:\n", + " # A thought with 4 steps but no final answer is a dead end\n", + " if len(thought.strip().split('\\n')) == 4 and 'answer' not in thought.lower():\n", + " return 0\n", + " # The verdict is the last line of each sampled evaluation\n", + " value_names = [_.split('\\n')[-1] for _ in evaluate_outputs]\n", + " value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}\n", + " score = sum(value * value_names.count(name) for name, value in value_map.items())\n", + " return score" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And finally, we wrap the above functions into the ```evaluate_thoughts``` function." 
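To make the scoring concrete, here is ```evaluate_outputs_unwrap``` restated standalone and exercised on hypothetical sampled evaluations (two `sure`, one `likely`); with the value map above, that gives 2 * 20 + 1 * 1 = 41:

```python
def evaluate_outputs_unwrap(thought: str, evaluate_outputs: list) -> float:
    # A thought with 4 steps but no final answer is a dead end
    if len(thought.strip().split('\n')) == 4 and 'answer' not in thought.lower():
        return 0
    # The verdict is the last line of each sampled evaluation
    value_names = [o.split('\n')[-1] for o in evaluate_outputs]
    value_map = {'impossible': 0.001, 'likely': 1, 'sure': 20}
    return sum(value * value_names.count(name) for name, value in value_map.items())

# Hypothetical evaluations sampled from the model for one thought
outputs = ['10 + 14 = 24\nsure', '10 + 14 = 24\nsure', 'close to 24\nlikely']
print(evaluate_outputs_unwrap('10 + 14 = 24 (left: 24)\n', outputs))  # 41.0
```

Giving `sure` a much larger weight than `likely` means one confident verdict dominates several hedged ones, while `impossible` contributes almost nothing.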
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_thoughts(data, thoughts, n_evaluate_sample):\n", + " scores = []\n", + " for thought in thoughts:\n", + " evaluate_prompt = prepare_evaluate_prompt(data, thought)\n", + " #evaluate_outputs = gpt(evaluate_prompt, n=n_evaluate_sample, stop=None)\n", + " # Sample n_evaluate_sample evaluations to make the value estimate less noisy\n", + " evaluate_outputs = [mistral(evaluate_prompt) for _ in range(n_evaluate_sample)]\n", + " score = evaluate_outputs_unwrap(thought, evaluate_outputs)\n", + " scores.append(score)\n", + " return scores" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.3 Search algorithm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, we'll implement the \"Search Algorithm\" which will be used to search through the thoughts generated by the LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "def search_algorithm(new_thoughts, ids, scores, n_select_sample):\n", + " selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample] # Take the top n_select_sample thoughts based on scores\n", + " select_new_thoughts = [new_thoughts[select_id] for select_id in selected_ids] \n", + " return select_new_thoughts" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Run ToT with sample data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll test our implementation with some sample data, i.e. the sequence 4 5 6 10, by combining the functions from above. If ToT works successfully, it should output the operations that can be performed to reach 24." 
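The greedy pruning in ```search_algorithm``` can be sanity-checked without any model calls. With hypothetical thoughts and scores, only the top ```n_select_sample``` thoughts survive to the next level of the tree:

```python
def search_algorithm(new_thoughts, ids, scores, n_select_sample):
    # Keep the n_select_sample highest-scoring thoughts (the greedy BFS step)
    selected_ids = sorted(ids, key=lambda x: scores[x], reverse=True)[:n_select_sample]
    return [new_thoughts[i] for i in selected_ids]

# Hypothetical candidate thoughts and their evaluation scores
thoughts = ['4 + 5 = 9 (left: 6 9 10)\n', '10 - 6 = 4 (left: 4 4 5)\n',
            '5 * 6 = 30 (left: 4 10 30)\n', '10 + 5 = 15 (left: 4 6 15)\n']
scores = [1.0, 40.0, 0.003, 21.0]

# Keeps the two thoughts with scores 40.0 and 21.0, in score order
print(search_algorithm(thoughts, [0, 1, 2, 3], scores, 2))
```

This is the breadth-first-search variant of ToT: the frontier is capped at `n_select_sample` thoughts per step, so the tree never grows beyond a fixed width.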
+ ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "data = '4 5 6 10'" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import partial\n", + "\n", + "\n", + "def solve(model, temperature, n_evaluate_sample, n_select_sample):\n", + "\n", + " #global gpt\n", + " #gpt = partial(gpt, model=model, temperature=temperature)\n", + " \n", + " thoughts = ['']\n", + " data = '4 5 6 10'\n", + "\n", + " # Run a single ToT step for demonstration; a full Game of 24 run needs 4 steps\n", + " steps = 1\n", + "\n", + " for step in range(steps):\n", + "\n", + " print('Step Number ::', step)\n", + " print('(Step 0) Current Thoughts: ', thoughts)\n", + "\n", + " # Step 1: Thought Generation\n", + " new_thoughts = generate_thoughts(data, thoughts)\n", + " ids = list(range(len(new_thoughts)))\n", + "\n", + " print('(Step 1) New Thoughts: ', new_thoughts)\n", + " print('(Step 1) ids ', ids)\n", + "\n", + " # Step 2: Thought Evaluation\n", + " scores = evaluate_thoughts(data, new_thoughts, n_evaluate_sample)\n", + " print('(Step 2) Scores: ', scores)\n", + "\n", + " # Step 3: Search algorithm\n", + " \n", + " selected_new_thoughts = search_algorithm(new_thoughts, ids, scores, n_select_sample) \n", + " print('(Step 3) Selected new thoughts: ', selected_new_thoughts)\n", + "\n", + " thoughts = selected_new_thoughts\n", + "\n", + " print('--------')\n", + "\n", + " return thoughts" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step Number :: 0\n", + "(Step 0) Current Thoughts: ['']\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.\n" + ] + }, + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "(Step 1) New Thoughts: ['Input: 2 8 8 14\\n', 'Possible next steps:\\n', '2 + 8 = 10 (left: 8 10 14)\\n', '8 / 2 = 4 (left: 4 8 14)\\n', '14 + 2 = 16 (left: 8 8 16)\\n', '2 * 8 = 16 (left: 8 14 16)\\n', '8 - 2 = 6 (left: 6 8 14)\\n', '14 - 8 = 6 (left: 2 6 8)\\n', '14 / 2 = 7 (left: 7 8 8)\\n', '14 - 2 = 12 (left: 8 8 12)\\n', 'Input: 4 5 6 10\\n', 'Possible next steps:\\n', '4 + 5 = 9 (left: 5 6 10)\\n', '5 / 4 = 1 (left: 1 5 6)\\n', '6 + 5 = 11 (left: 5 6 11)\\n', '6 - 5 = 1 (left: 1 6 11)\\n', '10 + 4 = 14 (left: 4 6 14)\\n', '4 * \\n']\n", + "(Step 1) ids [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Setting `pad_token_id` to `eos_token_id`:2 for open-end 
generation.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Step 2) Scores: [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n", + "(Step 3) Selected new thoughts: ['Input: 2 8 8 14\\n', 'Possible next steps:\\n', '2 + 8 = 10 (left: 8 10 14)\\n', '8 / 2 = 4 (left: 4 8 14)\\n', '14 + 2 = 16 (left: 8 8 16)\\n']\n", + "--------\n" + ] + }, + { + "data": { + "text/plain": [ + "['Input: 2 8 8 14\\n',\n", + " 'Possible next steps:\\n',\n", + " '2 + 8 = 10 (left: 8 10 14)\\n',\n", + " '8 / 2 = 4 (left: 4 8 14)\\n',\n", + " '14 + 2 = 16 (left: 8 8 16)\\n']" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "solve(model='gpt-3.5-turbo', temperature=0.7, n_evaluate_sample=3, n_select_sample=5)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:.conda-rahul_env]", + "language": "python", + "name": "conda-env-.conda-rahul_env-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}