From 3e8e4915d930f3b9ecf57ca2639695de608f824f Mon Sep 17 00:00:00 2001 From: Leandro von Werra Date: Wed, 30 Aug 2023 11:44:06 +0200 Subject: [PATCH] TextEnvironments (#424) * WIP skeleton * minimal working poc * cleanup * rename variables * quick typo fix * add v1 masking (#429) * add v1 masking * working v1 * adapt from suggestion * avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.` * fix masking - mask the responses from API call only * quality * address comments * Update trl/environment/base.py Co-authored-by: Leandro von Werra * adapt a bit * wip on tokenization/masking in textenv * small fixes * update viz * add example * print debug text and pass masks * style * format and move tensor to device * update example * update example * This seems to work * fix masking * fix rich output to console --------- Co-authored-by: Costa Huang Co-authored-by: Leandro von Werra Co-authored-by: leandro * Add masking (#461) * add v1 masking * working v1 * adapt from suggestion * avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.` * fix masking - mask the responses from API call only * quality * address comments * Update trl/environment/base.py Co-authored-by: Leandro von Werra * adapt a bit * wip on tokenization/masking in textenv * small fixes * update viz * add example * print debug text and pass masks * style * format and move tensor to device * update example * update example * This seems to work * fix masking * fix rich output to console * fix batched generation * improve stopping criteria * improve error handling in tool call --------- Co-authored-by: younesbelkada Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: Costa Huang * fix uknown tool * fix rewards and increase bs * remove unused script * ugly WIP fix * do not return modified obj for in-place operations * do not return modified obj for in-place operations * clean up stopping criterium * push updates * push update * format, add docs * rename file * add kwargs to reward fn * simplify example * simplify example * bug fix * add a trivia example * pre-commit * max tool response length * fix regex for multi-line * refactor tool exceptions * fix exceptions in tool * add docs * fix style * make rich optional * add docstrings * add tests * add TextEnv tests (WIP) * update triviaqa code * update docs * refactor text env * update tests (WIP) * add end2end test * update docs * upload tool demo * refactor * customizable system prompt * add text env docs * update index and toc * fix `TextHistory` show methods * add max length * fix style * fix typo * refactor to kwargs in init and tasks to queries * kwargs for reward docs * Update examples/triviaqa.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update examples/tool_demo.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/learning_tools.mdx Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/learning_tools.mdx Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/learning_tools.mdx Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update docs/source/text_environments.md Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update examples/triviaqa.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * Update examples/triviaqa.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * move to tool folder * remove assets * remove tool demo * move rich import test to import utils * add copyright * fixes for masks in ppo trainer * add text env api docs * make precommit + add ppo test with mask * move examples and add python * fix style * update triviaqa example * add more docs * update docs * Update docs/source/learning_tools.mdx * Apply suggestions from code review * precommit --------- Co-authored-by: Costa Huang Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: younesbelkada Co-authored-by: leandro von werra --- docs/source/_toctree.yml | 4 + docs/source/index.mdx | 2 + docs/source/learning_tools.mdx | 229 +++++++++ docs/source/text_environments.md | 197 ++++++++ .../research_projects/tools/calculator.py | 119 +++++ .../tools/python_interpreter.py | 194 +++++++ examples/research_projects/tools/triviaqa.py | 189 +++++++ tests/test_environments.py | 273 ++++++++++ tests/test_ppo_trainer.py | 34 +- trl/__init__.py | 1 + trl/environment/__init__.py | 3 + trl/environment/base_environment.py | 473 ++++++++++++++++++ trl/import_utils.py | 4 + trl/trainer/ppo_trainer.py | 56 ++- 14 files changed, 1769 insertions(+), 9 deletions(-) create mode 100644 docs/source/learning_tools.mdx create mode 100644 docs/source/text_environments.md create mode 100644 examples/research_projects/tools/calculator.py create mode 100644 examples/research_projects/tools/python_interpreter.py create mode 100644 examples/research_projects/tools/triviaqa.py create mode 100644 tests/test_environments.py create mode 100644 trl/environment/__init__.py create mode 100644 trl/environment/base_environment.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 27a88bf5aaa..11795be4962 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -29,6 +29,8 @@ title: DPO Trainer - local: ddpo_trainer title: Denoising Diffusion Policy Optimization + - local: text_environments + title: Text Environments title: API - sections: - local: sentiment_tuning @@ -39,6 +41,8 @@ title: Detoxifying a Language Model - local: using_llama_models title: Training StackLlama + - local: learning_tools + title: Learning to Use Tools - local: multi_adapter_rl title: Multi Adapter RLHF title: Examples diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 4c75cf556cd..af1bcb8ca1e 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -21,6 +21,7 @@ Check the appropriate sections of the documentation depending on your needs: - [`PPOTrainer`](trainer): *Further fine-tune the supervised fine-tuned model using PPO algorithm* - [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model* - [`DPOTrainer`](trainer): *Direct Preference Optimization training using `DPOTrainer`.* +- [`TextEnvironment`](text_environment): *Text environment to train your model using tools with RL.* ## Examples @@ -28,6 +29,7 @@ Check the appropriate sections of the documentation depending on your needs: - [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT* - [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF* - [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset* +- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`* - [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training* diff --git a/docs/source/learning_tools.mdx b/docs/source/learning_tools.mdx new file mode 100644 index 00000000000..a5ff8fce2d6 --- /dev/null +++ b/docs/source/learning_tools.mdx @@ -0,0 +1,229 @@ +# Learning Tools (Experimental ๐Ÿงช) + +Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://arxiv.org/abs/2302.04761) and [ToolBench](https://arxiv.org/pdf/2305.16504.pdf). In TRL, we provide a simple example of how to teach LLM to use tools with reinforcement learning. + + +Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools): + +| File | Description | +|---|---| +| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train LLM to use a calculator with reinforcement learning. | +| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train LLM to use a wiki tool to answer questions. | +| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train LLM to use python interpreter to solve math puzzles. | + + + +Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs. + + + +## Learning to Use a Calculator + + +The rough idea is as follows: + +1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parse a text calculation like `"14 + 34"` and return the calulated number: + ```python + from transformers import AutoTokenizer, load_tool + tool = load_tool("ybelkada/simple-calculator") + tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places + ``` +1. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later. +1. Create a prompt on how to use the tools + ```python + # system prompt + prompt = """\ + What is 13.1-3? + + 13.1-310.1 + + Result=10.1 + + What is 4*3? + + 4*312 + + Result=12 + + What is 12.1+1? + + 12.1+113.1 + + Result=13.1 + + What is 12.1-20? + + 12.1-20-7.9 + + Result=-7.9""" + ``` +3. Create a `trl.TextEnvironment` with the model + ```python + env = TextEnvironment( + model, + tokenizer, + {"SimpleCalculatorTool": tool_fn}, + reward_fn, + prompt, + generation_kwargs=generation_kwargs, + ) + ``` +4. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will show visualize the tokens. + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png) +1. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss, make sure to pass that argument to `step`. + +## Experiment results + +We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster. + +``` +WANDB_TAGS="calculator_final" python benchmark/benchmark.py \ + --command "python examples/calculator_few_shots_env.py" \ + --num-seeds 10 \ + --start-seed 1 \ + --workers 10 \ + --slurm-gpus-per-task 1 \ + --slurm-ntasks 1 \ + --slurm-total-cpus 8 \ + --slurm-template-path benchmark/trl.slurm_template +``` + +We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot. +``` +python -m openrlbenchmark.rlops_multi_metrics \ + --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \ + 'wandb?tag=calculator_final&cl=calculator_mask' \ + --env-ids trl \ + --check-empty-runs \ + --pc.ncols 2 \ + --pc.ncols-legend 1 \ + --output-filename static/0compare \ + --scan-history +``` + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png) + +As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task. + + +## (Early Experiments ๐Ÿงช): learning to use a wiki tool for question answering + +In the [ToolFormer](https://arxiv.org/abs/2302.04761) paper, it shows an interesting use case that utilizes a Wikipedia Search tool to help answer questions. In this section, we attempt to perform similar experiments but uses RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset. + + + + +**Note that many settings are different so the results are not directly comparable.** + + + + + +### Building a search index + +Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open source, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT) + +Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index. + +```python +from pyserini.search.lucene import LuceneSearcher +import json +searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc') +def search(query): + hits = searcher.search(query, k=1) + hit = hits[0] + contents = json.loads(hit.raw)['contents'] + return contents +print(search("tennis racket")) +``` +``` +Racket (sports equipment) +A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries. + +The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics. +... +``` + +We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png) + +### Experiment settings + +We use the following settings: + +* use the `bigcode/starcoderbase` model as the base model +* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only uses the first paragrahs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool. +* test if the response contain the answer string, if so, give a reward of 1, otherwise, give a reward of 0. + * notice this is a simplified evaluation criteria. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors checks if the first 20 words of the response contain the correct answer. +* used the following prompt that demonstrates the usage of the wiki tool. +```python +prompt = """\ +Answer the following question: + +Q: In which branch of the arts is Patricia Neary famous? +A: Ballets +A2: Patricia NearyPatricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe. +Result=Ballets + +Q: Who won Super Bowl XX? +A: Chicago Bears +A2: Super Bowl XXSuper Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46โ€“10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans. +Result=Chicago Bears + +Q: """ +``` + + +### Result and Discussion + + +Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves would go up mostly, but one of the experiment did crash. + +![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png) + +Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection. + + +Note that the correct rate of the trained model is on the low end, which could be due to the following reasons: + +* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"` if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988.` But a correct search should be `Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985โ€“1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988โ€“2013) and other roles.[1][2]" + + + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png) + +* **unnecessarily long response**: The wiki tool by default sometimes output very long sequences. E.g., when the wiki tool searches for "Brown Act" + * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies." + * [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct. + + ![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png) + + +## (Early Experiments ๐Ÿงช): solving math puzzles with python interpreter + +In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following: + +```python +prompt = """\ +Example of using a Python API to solve math questions. + +Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + + +def solution(): + money_initial = 23 + bagels = 5 + bagel_cost = 3 + money_spent = bagels * bagel_cost + money_left = money_initial - money_spent + result = money_left + return result +print(solution()) +72 + +Result = 72 + +Q: """ + + +Training results TBD. diff --git a/docs/source/text_environments.md b/docs/source/text_environments.md new file mode 100644 index 00000000000..851020e0f5c --- /dev/null +++ b/docs/source/text_environments.md @@ -0,0 +1,197 @@ +# Text Environments + +Text environments provide a learning ground for language agents. It allows a language model to use tools to accomplish a task such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the models itself but can be trivial for the appropriate tools. A good example is arithmetics of large numbers that become a simple copy-paste task once you have access to a calculator. + +
+ +
+ +Let's dive into how text environments work and start with tools! + +## Tools + +One of the core building blocks of text environments are tools that the model can use to solve tasks. In general tools can be any Python function that takes a string as input and returns string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with `__call__` method. Let's have a look at both! + +### `transformers.Tool` + +Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared + +```Python +from transformers import load_tool + +# simple calculator tool that runs +-/* operations +calc_tool = load_tool("ybelkada/simple-calculator") + +# python interpreter that executes program and returns outputs +py_tool = load_tool("lvwerra/python-interpreter") + +# wikipedia search index that returns best search match +wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc") +``` + +These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query: + +```Python +calc_tool("1/2") +>>> "0.5" +``` + +Note that both input and return values are strings to enable easy usage with a language model. + +### Custom Tools + +The following is an example of a tool that adds two integers: + +```Python +def add(text): + int_1, int_2 = text.split("+") + result = int(int_1) + int(int_2) + return str(result) + +print(add("1+1")) +>>> "2" +``` + +We looked at basic examples such as a calculator but the principle holds for more complex tools as well such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax. + +### Call syntax + +In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows: + +```python +"QUERYTOOL_RESPONSE" +``` + +There are a few special tokens involved so let's decompose it: First the model can signal that it wants to use a tool by emitting the `` token. After that we want to know the name of the tool to call which is done by enclosing the tool name with `<>` brackets. Once we know which tool to call the tool query follows which is in free text form. The `` tokens signifies the end of the query and stops the model generation. At this point the model output is parsed and the query sent to the tool. The environment appends the tool response to the string followed by the `` token to show the end the tool output. + +Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later): + +```python +"1/20.5" +``` + +Finally, the episode is ended and generation stops when the model generates `` which marks the interaction as completed. + +Now let's have a look how we can create a new text environment! + +## Create a `TextEnvironment` + + +```python +prompt = """\ +What is 13-3? +13-310.0 +Result=10 +""" + +def reward_fn(result, answer): + """Simplified reward function returning 1 if result matches answer and 0 otherwise.""" + result_parsed = result.split("=")[1].split("<")[0] + return int(result_parsed==answer) + +text_env = TextEnvironemnt( + model=model, + tokenizer=tokenizer, + tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, + reward_fn=exact_match_reward, + prompt=prompt, + max_turns=1 + max_tool_response=100 + generation_kwargs={"do_sample": "true"} +) +``` + +Let's decompose the settings: + +| Argument | Description | +|:-------------------|:----------------| +| `model` | Language model to interact with the environment and generate requests. | +| `tokenizer` | Tokenizer of language model handling tokenization of strings. | +| `tools` | `list` of `dict` of tools. If former the name of the tool is inferred from class name and otherwise it's the keys of the dictionary.| +| `reward_fn` | A function that takes a string as input and returns. Can have extra arguments that are passed to `.run()` such as ground truth.| +| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. | +| `max_turns` | Maximum number of interactions between model and tools before episode ends.| +| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.| +| `max_length` | The maximum number of tokens to allow in an episode. | +| `generation_kwargs`| Generation settings used by the language model. | + +You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools! + + +## Run an Episode + +To run a set of queries through the text environment one can simply use the `run` method. + +```python +queries = ["What is 1/2?"] +answers = ["0.5"] + +queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers) +``` + +This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached or to maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function. + +There are five objects that are returned by `run`: + +- `queries`: a list of the tokenized queries +- `responses`: all tokens that have been generated withing the environment including model and tool tokens +- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool +- `rewards`: a list of reward for each query/response +- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents + +The masks are crucial for training as we don't want to optimize tokens that the model has not generated which are tokens produced by the tools. + +Next, we'll train a PPO step with the generated responses! + + +### Train +Training on episodes from the `TextEnvironment` is straight forward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method: + +```python +train_stats = ppo_trainer.step(queries, responses, rewards, masks) +``` + +## `TextHistory` + +The `TextHistory` object stores the interactions between the model and the text environment. It stores tokens and text generated in each turn and their source in each turn (model or system) as well as rewards. Let's go through the class attributes and methods. + +### Attributes + +The following table summarises the available attributes of the `TextEnvironment` class: + +| Attribute | Description | +|:-------------------|:----------------| +| `text` | The full string of the text generated in the text environment with both model and system generated text. | +| `text_spans` | A list of tuples with the spans for each model or system generated text segment. | +| `system_spans` | A list of boolean values indicating if the segment is model or system generated. | +| `tokens` | All tokens generated in text environment with both model and system generated tokens. | +| `token_spans` | Similar to `text_spans` the `token_spans` indicate the boundaries of model andsystem generated tokens. | +| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. | +| `completed` | Indicates if the interaction with the environment has completed. | +| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. | + +With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look! + +### Visualization + +When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` libray](https://github.com/Textualize/rich) (make sure to install it before using these methods). + +You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue and in addition to the pure text output the reward is displayed as additional text in plum. Here an example of `show_text`: + +
+ +
+ +Sometimes there can be tricky tokenization related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`: + +
+ +
+ +Note that you can turn on the colour legend by passing `show_legend=True`. + +## API Documentation + +[[autodoc]] TextEnvironment + +[[autodoc]] TextHistory diff --git a/examples/research_projects/tools/calculator.py b/examples/research_projects/tools/calculator.py new file mode 100644 index 00000000000..deb5abfa716 --- /dev/null +++ b/examples/research_projects/tools/calculator.py @@ -0,0 +1,119 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import numpy as np +import torch +from transformers import AutoTokenizer, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +def generate_data(n): + """Generate random arithmetic tasks and answers.""" + tasks, answers = [], [] + for _ in range(n): + a = np.random.randint(0, 50) + b = np.random.randint(0, 50) + op = np.random.choice(["-", "+", "*"]) + tasks.append(f"\n\nWhat is {a} {op} {b}?") + if op == "-": + answers.append(a - b) + elif op == "+": + answers.append(a + b) + else: + answers.append(a * b) + return tasks, answers + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*" # generated by chatGPT + for response, answer in zip(responses, answers): + reward = 0.0 + predicted_number = None + match_pattern = re.findall(pattern, response) + if match_pattern: + predicted_number = float(match_pattern[0]) + if predicted_number is not None: + if np.abs(predicted_number - answer) < 0.01: + reward += 1.0 + rewards.append(torch.tensor(reward)) + return rewards + + +# set up models +model_id = "gpt2" +model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) +model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +tokenizer.pad_token = tokenizer.eos_token + +# system prompt +prompt = """\ +What is 13-3? + +13-310.0 + +Result=10 + +What is 4*3? + +4*312.0 + +Result=12""" + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": 32, +} + +# trainer +ppo_config = PPOConfig( + batch_size=256, + learning_rate=1.41e-5, + mini_batch_size=64, + log_with="wandb", +) +ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer) + +# text env +text_env = TextEnvironment( + model, + tokenizer, + {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")}, + exact_match_reward, + prompt, + generation_kwargs=generation_kwargs, +) + +# main training loop +for step in range(100): + tasks, answers = generate_data(ppo_config.batch_size) + queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + + response_texts = [tokenizer.decode(response) for response in responses] + query_texts = [tokenizer.decode(query) for query in queries] + texts = {"query": [qt.split("")[-1].strip() for qt in query_texts], "response": response_texts} + ppo_trainer.log_stats(train_stats, texts, rewards) +ppo_trainer.save_pretrained(model_id + "-calculator") diff --git a/examples/research_projects/tools/python_interpreter.py b/examples/research_projects/tools/python_interpreter.py new file mode 100644 index 00000000000..ebde5840885 --- /dev/null +++ b/examples/research_projects/tools/python_interpreter.py @@ -0,0 +1,194 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import re +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoTokenizer, HfArgumentParser, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"}) + learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"}) + ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"}) + n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"}) + + +parser = HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*" # generated by chatGPT + for response, answer in zip(responses, answers): + reward = 0.0 + try: + predicted_number = None + match_pattern = re.findall(pattern, response) + if match_pattern: + predicted_number = float(match_pattern[0]) + if predicted_number is not None: + if np.abs((predicted_number - float(answer))) < 0.1: + reward += 1.0 + except: # noqa + pass + rewards.append(torch.tensor(reward)) + return rewards + + +def evaluate(test_dataloader, text_env, ppo_trainer): + test_rewards = [] + for test_batch in test_dataloader: + _, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"]) + test_rewards.extend(rewards) + test_rewards = ppo_trainer.accelerator.gather_for_metrics( + torch.stack(test_rewards).to(ppo_trainer.accelerator.device) + ) + return test_rewards.mean() + + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["c_proj", "c_attn", "q_attn"], +) + +# set up models +model = AutoModelForCausalLMWithValueHead.from_pretrained( + args.model_name, + use_auth_token=True, + load_in_4bit=True, + peft_config=lora_config, +) +tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + +ds = load_dataset("gsm8k", "main", split="train") +ds = ds.rename_columns({"question": "query"}) +ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]}) +ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt + +ds_test = load_dataset("gsm8k", "main", split="test") +ds_test = ds_test.rename_columns({"question": "query"}) +ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]}) + +test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size) + +# prompt +prompt = """\ +Example of using a Python API to solve math questions. + +Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left? + + +def solution(): + money_initial = 23 + bagels = 5 + bagel_cost = 3 + money_spent = bagels * bagel_cost + money_left = money_initial - money_spent + result = money_left + return result +print(solution()) +72 + +Result = 72 + +Q: """ + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": args.max_new_tokens, +} + +# trainer +ppo_config = PPOConfig( + batch_size=args.batch_size, + learning_rate=args.learning_rate, + mini_batch_size=args.mini_batch_size, + ppo_epochs=args.ppo_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + log_with="wandb", + tracker_project_name="trl-gsm8k", + remove_unused_columns=False, + optimize_cuda_cache=True, +) + +ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds) +test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader) + +# text env +text_env = TextEnvironment( + model, + tokenizer, + [load_tool("lvwerra/python-interpreter")], + exact_match_reward, + prompt, + max_turns=2, + generation_kwargs=generation_kwargs, +) + +# main training loop +for epoch in range(args.n_epochs): + for step, batch in enumerate(ppo_trainer.dataloader): + if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs + reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer) + else: + reward_mean_test = None + + queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"]) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + + # logging + if reward_mean_test is not None: + train_stats["env/reward_mean_test"] = reward_mean_test + texts = { + "query": batch["query"], + "response": [tokenizer.decode(response) for response in responses], + "answer": batch["answer"], + } + ppo_trainer.log_stats(train_stats, texts, rewards) + +reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer) +ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k") diff --git a/examples/research_projects/tools/triviaqa.py b/examples/research_projects/tools/triviaqa.py new file mode 100644 index 00000000000..51bca6a9954 --- /dev/null +++ b/examples/research_projects/tools/triviaqa.py @@ -0,0 +1,189 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Optional + +import torch +from datasets import load_dataset +from peft import LoraConfig +from transformers import AutoTokenizer, HfArgumentParser, load_tool + +from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment + + +os.environ["HF_ALLOW_CODE_EVAL"] = "1" +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +@dataclass +class ScriptArguments: + model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"}) + log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"}) + learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"}) + mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"}) + batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"}) + gradient_accumulation_steps: Optional[int] = field( + default=16, metadata={"help": "the number of gradient accumulation steps"} + ) + max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"}) + ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"}) + iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"}) + seed: Optional[int] = field(default=0, metadata={"help": "the random seed"}) + + +parser = HfArgumentParser(ScriptArguments) +args = parser.parse_args_into_dataclasses()[0] + +lora_config = LoraConfig( + r=16, + lora_alpha=32, + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + target_modules=["c_proj", "c_attn", "q_attn"], +) + +# set up models +model = AutoModelForCausalLMWithValueHead.from_pretrained( + args.model_name, + use_auth_token=True, + trust_remote_code=True, + load_in_4bit=True, + peft_config=lora_config, +) +tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True) +tokenizer.pad_token = tokenizer.eos_token + +# system prompt +prompt = """\ +Answer the following question: + +Q: In which branch of the arts is Patricia Neary famous? +A: Ballets +A2: Patricia NearyPatricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe. +Result=Ballets + +Q: Who won Super Bowl XX? +A: Chicago Bears +A2: Super Bowl XXSuper Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46โ€“10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans. +Result=Chicago Bears + +Q: """ + +generation_kwargs = { + "min_length": -1, + "top_k": 0.0, + "top_p": 1.0, + "do_sample": True, + "pad_token_id": tokenizer.eos_token_id, + "eos_token_id": -1, + "max_new_tokens": args.max_new_tokens, +} + +# trainer +config = PPOConfig( + batch_size=args.batch_size, + model_name=args.model_name, + learning_rate=args.learning_rate, + log_with=args.log_with, + mini_batch_size=args.mini_batch_size, + ppo_epochs=args.ppo_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + seed=args.seed, + optimize_cuda_cache=True, +) +ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer) +dataset = load_dataset("trivia_qa", "rc", split="train") +local_seed = args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime +dataset = dataset.shuffle(local_seed) + + +def data_generator(): + for i in range(len(dataset)): + yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]] + + +gen = data_generator() +gen = iter(gen) + + +def generate_data(n): + tasks, answers = [], [] + for i in range(n): + q, a = next(gen) + tasks.append(q) + answers.append(a) + return tasks, answers + + +def exact_match_reward(responses, answers=None): + """Reward if generated response contains correct answer.""" + rewards = [] + for response, answer in zip(responses, answers): + reward = 0.0 + for a in answer: + if a.lower() in response.lower(): + reward += 1.0 + break + rewards.append(torch.tensor(reward)) + return rewards + + +# text env +tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc") +# limit the amount if tokens +tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa +text_env = TextEnvironment( + model, + tokenizer, + {"Wiki": tool_fn}, + exact_match_reward, + prompt, + generation_kwargs=generation_kwargs, + max_tool_reponse=400, +) + + +def print_trainable_parameters(model): + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + if param.requires_grad: + trainable_params += param.numel() + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}" + ) + + +print_trainable_parameters(model) +# main training loop +for i in range(args.iterations): + tasks, answers = generate_data(config.batch_size) + queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers) + train_stats = ppo_trainer.step(queries, responses, rewards, masks) + response_texts = [tokenizer.decode(response) for response in responses] + query_texts = [tokenizer.decode(query) for query in queries] + texts = { + "query": [qt.split("")[-1].strip() for qt in query_texts], + "response": response_texts, + "answer": [", ".join(item) for item in answers], + } + all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device)) + ppo_trainer.log_stats(train_stats, texts, [item for item in all_rewards]) + if i % 100 == 0: + ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa") diff --git a/tests/test_environments.py b/tests/test_environments.py new file mode 100644 index 00000000000..e31daab5ceb --- /dev/null +++ b/tests/test_environments.py @@ -0,0 +1,273 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +import torch +from transformers import AutoTokenizer + +from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory + + +class DummyTool: + def __call__(self, text): + return text + + +def dummy_generate(histories): + for i in range(len(histories)): + histories[i].append_segment("test", torch.tensor([1, 2, 3]), system=False) + return histories + + +class TextHistoryTest(unittest.TestCase): + def test_text_history_init(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + self.assertEqual(history.text, text) + self.assertTrue(torch.equal(history.tokens, tokens)) + self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens))) + + history = TextHistory(text, tokens, system=False) + self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens))) + + def test_text_history_append_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + self.assertEqual(history.text, text + "General Kenobi!") + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1]))) + + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!") + self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0]))) + + def test_text_history_complete(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.complete() + self.assertTrue(history.completed) + self.assertFalse(history.truncated) + + history.complete(truncated=True) + self.assertTrue(history.completed) + self.assertTrue(history.truncated) + + def test_text_history_last_segment(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6])) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9])) + self.assertEqual(history.last_text_segment, "You are a bold one!") + + def test_text_history_split_query_response(self): + text = "Hello there!" + tokens = torch.tensor([1, 2, 3]) + history = TextHistory(text, tokens) + history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False) + history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True) + query, response, mask = history.split_query_response_tokens() + + self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3]))) + self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9]))) + self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0]))) + + +class TextEnvironmentTester(unittest.TestCase): + @classmethod + def setUpClass(cls): + # model_id + cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + + # get models and tokenizer + cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id) + cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token + + def test_text_environment_setup(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + self.assertEqual(env.prompt, "I am a prompt!\n") + self.assertEqual(list(env.tools.keys()), ["DummyTool"]) + self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool)) + self.assertEqual(env.reward_fn("Hello there!"), 1) + + def test_text_environment_generate(self): + generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id} + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + generation_kwargs=generation_kwargs, + ) + + input_texts = ["this is a test", "this is another, longer test"] + + model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts] + + generations_batched = env._generate_batched(model_inputs, batch_size=2) + generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched) + + generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs] + generations_single = self.gpt2_tokenizer.batch_decode(generations_single) + + self.assertEqual(generations_single, generations_batched) + + def test_text_environment_tool_call_parsing(self): + string_valid = "Something something Hello there!" + string_invalid_request = "Something something Hello there!" + string_invalid_call = "Something something Hello there!" + string_invalid_tool = "Something something |Tool2|Hello there!" + string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>" + + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools=[DummyTool()], + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + tool, response = env.parse_tool_call(string_valid) + self.assertEqual(tool, "Tool1") + self.assertEqual(response, "Hello there!") + + tool, response = env.parse_tool_call(string_invalid_request) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_call) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_tool) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + tool, response = env.parse_tool_call(string_invalid_random) + self.assertEqual(tool, None) + self.assertEqual(response, None) + + def test_text_environment_tool_truncation(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"dummy": lambda x: "a" * 1000}, + reward_fn=lambda x: torch.tensor(1), + prompt="I am a prompt!\n", + ) + + env.max_tool_response = 100 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100) + + env.max_tool_response = 500 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500) + + env.max_tool_response = 1001 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000) + + env.max_tool_response = 2000 + history = env.step(TextHistory("Hello there!", torch.tensor([1, 2, 3]))) + self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_max_calls(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(1) for _ in x], + prompt="I am a prompt!\n", + ) + + env.max_turns = 1 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 1 * "testtest" + ) + + env.max_turns = 2 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 2 * "testtest" + ) + + env.max_turns = 4 + _, _, _, _, histories = env.run(["test"]) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "test" + 4 * "testtest" + ) + + def test_text_environment_compute_rewards(self): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + ) + + histories = [TextHistory("test", torch.tensor([1, 2, 3])) for _ in range(8)] + histories = env.compute_reward(histories) + + for i in range(8): + self.assertEqual(histories[i].reward, i) + + @patch.object(TextEnvironment, "generate", side_effect=dummy_generate) + def test_text_environment_run(self, mock_generate): + env = TextEnvironment( + self.gpt2_model, + self.gpt2_tokenizer, + tools={"DummyTool": DummyTool()}, + reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)], + prompt="I am a prompt!\n", + max_turns=2, + ) + task_1 = "Hello there!" + task_2 = "Hello there! General Kenobi!" + + query, response, response_mask, reward, histories = env.run([task_1, task_2]) + self.assertEqual(len(query[0]), 9) + self.assertEqual(len(query[1]), 12) + self.assertEqual(len(response[0]), 14) + self.assertEqual(len(response[1]), 14) + self.assertEqual(response_mask[0].sum(), 2 * 3) # mocked generate always adds 3 toknes + self.assertEqual(response_mask[1].sum(), 2 * 3) # mocked generate always adds 3 toknes + self.assertEqual(reward[0], 0) + self.assertEqual(reward[1], 1) + self.assertEqual( + histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "testtest" + ) + self.assertEqual( + histories[1].text, + "I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "testtest", + ) diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index 6c30ea1812c..a3a023b32f0 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -209,6 +209,38 @@ def test_ppo_step(self): for stat in EXPECTED_STATS: assert stat in train_stats.keys() + def test_ppo_step_with_masks(self): + # initialize dataset + dummy_dataset = self._init_dummy_dataset() + + ppo_trainer = PPOTrainer( + config=self.ppo_config, + model=self.gpt2_model, + ref_model=self.gpt2_model_ref, + tokenizer=self.gpt2_tokenizer, + dataset=dummy_dataset, + ) + dummy_dataloader = ppo_trainer.dataloader + # train model with ppo + for query_tensor, response_tensor in dummy_dataloader: + # define a reward for response + # (this could be any reward such as human feedback or output from another model) + reward = [torch.tensor(1.0), torch.tensor(0.0)] + + response_mask = [torch.ones_like(r) for r in response_tensor] + + # train model + train_stats = ppo_trainer.step( + [q for q in query_tensor], [r for r in response_tensor], reward, response_mask + ) + break + + for param in ppo_trainer.model.parameters(): + assert param.grad is not None + + for stat in EXPECTED_STATS: + assert stat in train_stats.keys() + def test_ppo_step_with_no_ref_sgd(self): # initialize dataset dummy_dataset = self._init_dummy_dataset() @@ -466,7 +498,7 @@ def test_ppo_step_input_shape(self): # train model - this should raise an error bs = ppo_trainer.config.batch_size - queries, responses, _ = ppo_trainer._step_safety_checker( + queries, responses, _, _ = ppo_trainer._step_safety_checker( bs, [q for q in query_tensor], [r for r in response_tensor], reward ) diff --git a/trl/__init__.py b/trl/__init__.py index 0930f22aa12..5ca2028a789 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -3,6 +3,7 @@ __version__ = "0.6.1.dev0" from .core import set_seed +from .environment import TextEnvironment, TextHistory from .extras import BestOfNSampler from .import_utils import is_diffusers_available, is_peft_available from .models import ( diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py new file mode 100644 index 00000000000..ae1cda4ecb2 --- /dev/null +++ b/trl/environment/__init__.py @@ -0,0 +1,3 @@ +# flake8: noqa + +from .base_environment import TextEnvironment, TextHistory diff --git a/trl/environment/base_environment.py b/trl/environment/base_environment.py new file mode 100644 index 00000000000..25f44ae9355 --- /dev/null +++ b/trl/environment/base_environment.py @@ -0,0 +1,473 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import warnings + +import torch +from accelerate.utils import extract_model_from_parallel +from transformers import StoppingCriteria, StoppingCriteriaList + +from ..import_utils import is_rich_available + + +if is_rich_available(): + from rich import print + from rich.text import Text + + +class StringStoppingCriteria(StoppingCriteria): + """Custom `StoppingCriteria` which checks if all generations in the batch are completed.""" + + def __init__(self, stop_strings, tokenizer): + self.stop_strings = stop_strings + self.tokenizer = tokenizer + self.first_call = True + + def __call__(self, input_ids, scores, **kwargs): + """Returns true if all generated sequences contain any of the stop strings.""" + if self.first_call: + self.generated_tokens = [1 for _ in range(input_ids.shape[0])] + self.start_length = input_ids.shape[-1] - 1 + self.first_call = False + decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :]) + done = [] + + for i, decoded_generation in enumerate(decoded_generations): + sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings]) + done.append(sequence_complete) + if not sequence_complete: + self.generated_tokens[i] += 1 + + if all(done): + self.first_call = True + + return all(done) + + +class TextHistory: + """The TextHistory class keeps track of the history of an interaction between the language model and the environment.""" + + def __init__(self, text, tokens, system=True): + """ + Initialize TextHistory. + + args: + text (`str`): The text of the first segment. + tokens (`torch.LongTensor`): The tokens of the first segment. + system (`bool`, *optional*): Whether the first segment is a system or user segment. + """ + self.system_spans = [] + self.text_spans = [] + self.token_spans = [] + self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device) + self.text = "" + self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device) + self.completed = False + self.truncated = False + self.reward = 0.0 + + self.prompt_color = "black on grey85" + self.system_color = "black on cyan3" + self.model_color = "black on deep_sky_blue1" + self.reward_color = "black on plum1" + + self.append_segment(text, tokens, system=system) + + def append_segment(self, text, tokens, system=True): + """ + Append a new segment to the history. + + args: + text (`str`): The text of the new segment. + tokens (`torch.LongTensor`): The tokens of the new segment. + system (`bool`, *optional*): Whether the new segment is a system or user segment. + """ + + if len(text) == 0 or len(tokens) == 0: + raise ValueError("Can't append empty text or token list to history.") + + original_text_length = len(self.text) + + self.text += text + self.text_spans.append((original_text_length, len(self.text))) + self.system_spans.append(system) + + original_token_length = len(self.tokens) + + self.tokens = torch.cat((self.tokens, tokens)) + if system: + self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens))) + else: + self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens))) + self.token_spans.append((original_token_length, len(self.tokens))) + + def complete(self, truncated=False): + """ + Mark the history as completed. + """ + self.completed = True + self.truncated = truncated + + @property + def last_text_segment(self): + """ + Get the last text segment. + """ + start, end = self.text_spans[-1] + return self.text[start:end] + + def split_query_response_tokens(self): + """ + Split the tokens into query and response tokens. + """ + split_index = self.token_spans[0][1] + query = self.tokens[:split_index] + response = self.tokens[split_index:] + mask = self.token_masks[split_index:] + + return query, response, mask + + def show_text(self, show_legend=False): + """ + Print the text history. + """ + if not is_rich_available(): + warnings.warn("install rich to display text") + return + + text = Text(self.text) + text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0]) + for i, (start, end) in enumerate(self.text_spans[1:]): + if self.system_spans[i + 1]: + text.stylize(self.system_color, start, end) + else: + text.stylize(self.model_color, start, end) + + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + + if show_legend: + self.show_colour_legend() + + def show_tokens(self, tokenizer, show_legend=False): + """ + Print the history tokens. + """ + if not is_rich_available(): + warnings.warn("install rich to display tokens") + return + + text = Text() + prompt_end = self.token_spans[0][1] + for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)): + if i < prompt_end: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color) + text.append(" ") + elif mask == 0: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color) + text.append(" ") + else: + text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color) + text.append(" ") + text.append(f"\n\nReward: {self.reward}", style=self.reward_color) + print(text) + if show_legend: + self.show_colour_legend() + + def show_colour_legend(self): + """ + Print the colour legend. + """ + if not is_rich_available(): + warnings.warn("install rich to display colour legend") + return + text = Text("\n\n(Colour Legend: ") + text.append("Prompt", style=self.prompt_color) + text.append("|") + text.append("System", style=self.system_color) + text.append("|") + text.append("Model", style=self.model_color) + text.append("|") + text.append("Reward", style=self.reward_color) + text.append(")") + print(text) + + +class TextEnvironment: + """ + The TextEnvironment enables interaction of a LLM with an environment using tools. + """ + + def __init__( + self, + model=None, + tokenizer=None, + tools=None, + reward_fn=None, + prompt=None, + max_turns=4, + max_tool_reponse=100, + max_length=None, + generation_kwargs=None, + ): + """ + Initialize TextEnvironment. + + Args: + model (`PreTrainedModelWrapper`): The model to use for generation. + tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation. + tools (list): A list of tools to use for interaction. + reward_fn (function): A function that takes a string and returns a reward. + prompt (str): The base prompt to use for generation. Is prepended to the tasks. + max_turns (Optional[int]): The maximum number of turns to allow. + max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response. + max_length (Optional[int]): The maximum number of tokens to allow in an episode. + generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method. + """ + self.model = model + self.tokenizer = tokenizer + self.prompt = prompt + if isinstance(tools, dict): + self.tools = tools + else: + self.tools = dict([(tool.__class__.__name__, tool) for tool in tools]) + self.reward_fn = reward_fn + self.max_length = max_length + self.request_token = "" + self.call_token = "" + self.response_token = "" + self.submit_token = "" + self.max_turns = max_turns + self.max_tool_response = max_tool_reponse + + if generation_kwargs is None: + self.generation_kwargs = dict() + else: + self.generation_kwargs = generation_kwargs + + self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") + self.current_device = extract_model_from_parallel(self.model).pretrained_model.device + + def run(self, queries, **rewards_kwargs): + """ + Run the environment on a list of queries. + + Args: + queries (list[str]): A list of queries to run the model in the environment on. + """ + turns = 0 + + queries = [self.prompt + task for task in queries] + queries_tokens = [ + self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device) + for query in queries + ] + + histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)] + + while any([not history.completed for history in histories]) and turns < self.max_turns: + histories = self.generate(histories) + histories = self.tasks_end_check(histories) + # TODO: make this parallel rather than for-loop + for i in range(len(histories)): + histories[i] = self.step(histories[i]) + histories = self.tasks_end_check(histories, model_turn=False) + turns += 1 + self.compute_reward(histories, **rewards_kwargs) + + # convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively + queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories])) + + rewards = [history.reward for history in histories] + return queries, responses, masks, rewards, histories + + def step(self, history): + """ + Step the environment forward one turn. + + Args: + history (`TextHistory`): The history to step forward. + """ + truncated, ended = self.task_end_check(history) + if ended: + history.complete(truncated=truncated) + if history.completed: + return history + + tool, query = self.parse_tool_call(history.last_text_segment) + if tool is None or query is None: + response = f"Unknown tool call: {history.last_text_segment}" + else: + if tool not in self.tools: + response = f"Unknown tool {tool}." + try: + response = self.tools[tool](query) + except Exception as error: + response = f"Tool error: {str(error)}" + + if len(response) > self.max_tool_response: + response = response[: (self.max_tool_response - 3)] + "..." + + history.append_segment( + response + self.response_token, + self.tokenizer(response + self.response_token, return_tensors="pt") + .input_ids[0] + .to(self.model.pretrained_model.device), + system=True, + ) + + return history + + def parse_tool_call(self, text): + """ + Parse request string. Expected format: query + """ + result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL) + + # if we can't find a / span we return none + if result is None: + return None, None + else: + extracted_text = result.group() + + result = re.search(r"<(.*?)>", extracted_text) + + # if we can't find a tool name we return none + if result is None: + return None, None + else: + tool = result.group(1) + + # split off the tool name + query = ">".join(extracted_text.split(">")[1:]) + + return tool, query + + def compute_reward(self, histories, **reward_kwargs): + """ + Compute the reward for a list of histories. + """ + rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs) + for history, reward in zip(histories, rewards): + history.reward = reward + return histories + + def generate(self, histories): + """ + Generate responses for a list of histories. + """ + active_histories = [i for i, history in enumerate(histories) if not history.completed] + + query_tensors = [histories[i].tokens for i in active_histories] + response_tensors = self._generate_batched(query_tensors) + response_texts = self.tokenizer.batch_decode(response_tensors) + + for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors): + histories[i].append_segment(response_text, response_tensor, system=False) + + return histories + + def tasks_end_check(self, histories, model_turn=True): + """ + Check if the current generation sequences have finished. + """ + for history in histories: + if not history.completed: + truncated, ended = self.task_end_check(history, model_turn=model_turn) + if ended: + history.complete(truncated=truncated) + return histories + + def task_end_check(self, history, model_turn=True): + """ + Check if the current generation sequence has finished. + """ + truncated = False + ended = False + if history.completed: + return truncated, ended + if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length: + truncated = True + ended = True + elif self.tokenizer.eos_token in history.text: + ended = True + elif model_turn and not ( + (self.request_token in history.last_text_segment and self.call_token in history.last_text_segment) + or self.submit_token in history.last_text_segment + ): + ended = True + elif self.submit_token in history.last_text_segment: + ended = True + return truncated, ended + + def _generate_batched( + self, + query_tensors, + batch_size: int = 16, + pad_to_multiple_of: int = None, + ): + """ + Generate responses for a list of query tensors. + + args: + query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for. + batch_size (int): The batch size to use for generation. + pad_to_multiple_of (int): The padding length to use for generation. + """ + outputs = [] + padding_side_default = self.tokenizer.padding_side + if not self.is_encoder_decoder: + self.tokenizer.padding_side = "left" + + # in case we have fewer examples than bs + batch_size = min(len(query_tensors), batch_size) + + for i in range(0, len(query_tensors), batch_size): + # prevent overflow if query tensors are not even multiple of bs + end_index = min(len(query_tensors), i + batch_size) + + batch = query_tensors[i:end_index] + batch_mask = [torch.ones_like(element) for element in batch] + inputs = {"input_ids": batch, "attention_mask": batch_mask} + + padded_inputs = self.tokenizer.pad( + inputs, + padding=True, + max_length=None, + pad_to_multiple_of=pad_to_multiple_of, + return_tensors="pt", + ).to(self.current_device) + + stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer) + + self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria]) + + generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs) + + for generation, mask, generated_tokens in zip( + generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens + ): + if not self.is_encoder_decoder: + output = generation[(1 - mask).sum() :] # remove padding + else: + output = generation + + if not self.is_encoder_decoder: + output = output[(mask).sum() :] # remove prompt + + # remove chunk generated after stopping criteria in batch mode + outputs.append(output[:generated_tokens]) + self.tokenizer.padding_side = padding_side_default + return outputs diff --git a/trl/import_utils.py b/trl/import_utils.py index bc631093371..6dca3c66e2e 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -47,3 +47,7 @@ def is_bitsandbytes_available(): def is_torchvision_available(): return importlib.util.find_spec("torchvision") is not None + + +def is_rich_available(): + return importlib.util.find_spec("rich") is not None diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index e50682cf18c..6992027bce3 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -207,6 +207,7 @@ def __init__( self.model_params = filter(lambda p: p.requires_grad, self.model.parameters()) self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder") self.is_peft_model = getattr(self.model, "is_peft_model", False) + self.is_using_text_environment = getattr(config, "use_text_environment", False) if isinstance(ref_model, SUPPORTED_ARCHITECTURES): self.ref_model = ref_model @@ -526,6 +527,7 @@ def _step_safety_checker( queries: List[torch.LongTensor], responses: List[torch.LongTensor], scores: List[torch.FloatTensor], + masks: Optional[List[torch.LongTensor]] = None, ): """ Check if the input data is valid for training. @@ -539,6 +541,8 @@ def _step_safety_checker( List of tensors containing the encoded responses of shape (`response_length`) scores (List[`torch.FloatTensor`]): List of tensors containing the scores. + masks (List[`torch.LongTensor`], *optional*): + list of optional tensors containing the masks of shape (`query_length` + `response_length`) Returns: `tuple`: The input processed data. """ @@ -556,6 +560,7 @@ def _step_safety_checker( queries = [tensor.to(self.current_device) for tensor in queries] responses = [tensor.to(self.current_device) for tensor in responses] scores = [tensor.to(self.current_device) for tensor in scores] + masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None # squeeze scores if needed for i, score in enumerate(scores): @@ -564,7 +569,7 @@ def _step_safety_checker( elif score.dim() == 1: scores[i] = score.squeeze() - return queries, responses, scores + return queries, responses, scores, masks @PPODecorators.empty_cuda_cache() def step( @@ -572,6 +577,7 @@ def step( queries: List[torch.LongTensor], responses: List[torch.LongTensor], scores: List[torch.FloatTensor], + response_masks: Optional[List[torch.LongTensor]] = None, ): """ Run a PPO optimisation step given a list of queries, model responses, and rewards. @@ -583,13 +589,17 @@ def step( List of tensors containing the encoded responses of shape (`response_length`) scores (List[`torch.FloatTensor`]): List of tensors containing the scores. + response_masks (List[`torch.FloatTensor`], *optional*)): + List of tensors containing masks of the response tokens. Returns: `dict[str, Any]`: A summary of the training statistics """ bs = self.config.batch_size - queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores) + queries, responses, scores, response_masks = self._step_safety_checker( + bs, queries, responses, scores, response_masks + ) scores = torch.tensor(scores) if self.config.use_score_scaling: # Score scaling @@ -654,9 +664,13 @@ def step( with torch.no_grad(): all_logprobs, logits_or_none, values, masks = self.batched_forward_pass( - self.model, queries, responses, model_inputs, return_logits=full_kl_penalty + self.model, + queries, + responses, + model_inputs, + response_masks=response_masks, + return_logits=full_kl_penalty, ) - # for when the model is a peft model if self.is_peft_model and hasattr( self.accelerator.unwrap_model(self.model).pretrained_model, @@ -879,7 +893,6 @@ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): input_data["decoder_input_ids"] = decoder_inputs["input_ids"] input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"] - else: input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)] input_data = self.data_collator( @@ -887,7 +900,6 @@ def prepare_model_inputs(self, queries: torch.Tensor, responses: torch.Tensor): ).to(self.current_device) input_data.pop("labels", None) # we don't want to compute LM losses - return input_data @PPODecorators.empty_cuda_cache() @@ -898,6 +910,7 @@ def batched_forward_pass( responses: torch.Tensor, model_inputs: dict, return_logits: bool = False, + response_masks: Optional[torch.Tensor] = None, ): """ Calculate model outputs in multiple batches. @@ -928,6 +941,8 @@ def batched_forward_pass( input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} query_batch = queries[i * fbs : (i + 1) * fbs] response_batch = responses[i * fbs : (i + 1) * fbs] + if response_masks is not None: + response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] logits, _, values = model(**input_kwargs) if self.is_encoder_decoder: @@ -951,9 +966,15 @@ def batched_forward_pass( if attention_mask[j, 0] == 0: # offset left padding start += attention_mask[j, :].nonzero()[0] end = start + len(response_batch[j]) + if response_masks is not None: + response_masks_batch[j] = torch.cat( + (torch.zeros_like(query_batch[j]), response_masks_batch[j]) + )[1:] masks[j, :start] = 0 masks[j, end:] = 0 + if response_masks is not None: + masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] if return_logits: all_logits.append(logits) @@ -1274,8 +1295,12 @@ def log_stats( elif self.config.log_with == "wandb": import wandb - table_rows = [list(r) for r in zip(batch["query"], batch["response"], rewards.cpu().tolist())] - logs.update({"game_log": wandb.Table(columns=["query", "response", "reward"], rows=table_rows)}) + table_rows = [ + list(r) for r in zip(batch["query"], batch["response"], batch["answer"], rewards.cpu().tolist()) + ] + logs.update( + {"game_log": wandb.Table(columns=["query", "response", "answer", "reward"], rows=table_rows)} + ) # All reduce rewards if distributed if self.is_distributed: import torch.distributed as dist @@ -1340,3 +1365,18 @@ def _save_pretrained(self, save_directory: str) -> None: self.accelerator.unwrap_model(self.model).save_pretrained(save_directory) self.tokenizer.save_pretrained(save_directory) self.create_model_card(save_directory) + + def _show_tokens(self, tokens, masks): + from rich import print + from rich.text import Text + + text = Text() + + for i, (token, mask) in enumerate(zip(tokens, masks)): + if mask == 1: + text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1") + text.append(" ") + else: + text.append(self.tokenizer.decode(token.item()), style="black on cyan3") + text.append(" ") + print(text)