Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TextEnvironments #424

Merged
merged 81 commits into from
Aug 30, 2023
Merged

TextEnvironments #424

merged 81 commits into from
Aug 30, 2023

Conversation

lvwerra
Copy link
Member

@lvwerra lvwerra commented Jun 9, 2023

This PR adds multi-turn text environment to TRL.

Target API

env = Environment(model, tokenizer, prompt, tools)

for tasks in ppo_trainer.dataloader:
    histories = env.run(tasks)
    tokens, mask = histories.get_tokens()
    ppo_trainer.step(tokens, mask=mask, histories.rewards)
    
# alternatively, this would probably be much nicer
for tasks in ppo_trainer.dataloader:
    tokens, masks, rewards, history = env.run(tasks)
    ppo_trainer.step(tokens, masks, rewards)

Todos

Current working example

from trl import TextEnvironment, TextHistory, AutoModelForCausalLMWithValueHead
from transformers import AutoModelForCausalLM, AutoTokenizer, load_tool

tool = load_tool("ybelkada/simple-calculator")

model_id = "gpt2-xl"

model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

prompt = """\
What is 12.1 + 1 - 3?
<request><SimpleCalculatorTool>12.1 + 1<call>13.1<response>
<request><SimpleCalculatorTool>13.1 - 3<call>10.1<response>
Result = 10.1 <submit>

"""

reward_fn = lambda x: 1

env = TextEnvironment(model, tokenizer,[tool], reward_fn, prompt, generation_kwargs={"max_new_tokens": 32})
h = env.run(["What is 387 * 228?"])

h[0].show()

Result:

What is 12.1 + 1 - 3?
<request><SimpleCalculatorTool>12.1 + 1<call>13.1<response>
<request><SimpleCalculatorTool>13.1 - 3<call>10.1<response>
Result = 10.1 <submit>

What is 387 * 228?

<request><SimpleCalculatorTool>387 * 228<call>88236.0<response>

Result = 88236.0 <submit>
Reward: 1

Where the 88236.0<response> segement was generated by the tool call.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 9, 2023

The documentation is not available anymore as the PR was closed or merged.

trl/environment/base.py Outdated Show resolved Hide resolved
trl/environment/base.py Outdated Show resolved Hide resolved
trl/environment/base.py Outdated Show resolved Hide resolved
younesbelkada and others added 20 commits June 23, 2023 11:31
* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: leandro <leandro.vonwerra@spoud.io>
* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

* fix batched generation

* improve stopping criteria

* improve error handling in tool call

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>
vwxyzjn and others added 7 commits August 28, 2023 15:24
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
@vwxyzjn
Copy link
Contributor

vwxyzjn commented Aug 28, 2023

Thanks @younesbelkada for the comments, I have addressed most of them. @lvwerra the PR looks good. Quick question: do we want to merge docs/source/learning_tools.mdx and docs/source/text_environments.md?

@lvwerra
Copy link
Member Author

lvwerra commented Aug 29, 2023

Quick question: do we want to merge docs/source/learning_tools.mdx and docs/source/text_environments.md?

I thought we could have text_environments.md as the basic doc for how TextEnvs work and the learning_tools.mdx as the more hands-on guide also linking the experiment scripts. Wdyt?

Also @vwxyzjn would you mind fixing the scripts so they pass the quality checks. In my opinion we can also exclude them from the quality tests, no strong opinion.

Copy link
Contributor

@younesbelkada younesbelkada left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking great ! Thanks for all your efforts, left 4 nits, otherwise LGTM !

examples/research_projects/tools/triviaqa.py Show resolved Hide resolved
trl/trainer/ppo_trainer.py Outdated Show resolved Hide resolved
@younesbelkada younesbelkada merged commit 9d09b3e into main Aug 30, 2023
@younesbelkada younesbelkada deleted the envs branch August 30, 2023 09:44
kushal-tri pushed a commit to kushalarora/trl that referenced this pull request Sep 19, 2023
* WIP skeleton

* minimal working poc

* cleanup

* rename variables

* quick typo fix

* add v1 masking (huggingface#429)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: leandro <leandro.vonwerra@spoud.io>

* Add masking (huggingface#461)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

* fix batched generation

* improve stopping criteria

* improve error handling in tool call

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>

* fix uknown tool

* fix rewards and increase bs

* remove unused script

* ugly WIP fix

* do not return modified obj for in-place operations

* do not return modified obj for in-place operations

* clean up stopping criterium

* push updates

* push update

* format, add docs

* rename file

* add kwargs to reward fn

* simplify example

* simplify example

* bug fix

* add a trivia example

* pre-commit

* max tool response length

* fix regex for multi-line

* refactor tool exceptions

* fix exceptions in tool

* add docs

* fix style

* make rich optional

* add docstrings

* add  tests

* add TextEnv tests (WIP)

* update triviaqa code

* update docs

* refactor text env

* update tests (WIP)

* add end2end test

* update docs

* upload tool demo

* refactor

* customizable system prompt

* add text env docs

* update index and toc

* fix `TextHistory` show methods

* add max length

* fix style

* fix typo

* refactor to kwargs in init and tasks to queries

* kwargs for reward docs

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/tool_demo.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/text_environments.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move to tool folder

* remove assets

* remove tool demo

* move rich import test to import utils

* add copyright

* fixes for masks in ppo trainer

* add text env api docs

* make precommit + add ppo test with mask

* move examples and add python

* fix style

* update triviaqa example

* add more docs

* update docs

* Update docs/source/learning_tools.mdx

* Apply suggestions from code review

* precommit

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: leandro von werra <leandro@hf.co>
lapp0 pushed a commit to lapp0/trl that referenced this pull request May 10, 2024
* WIP skeleton

* minimal working poc

* cleanup

* rename variables

* quick typo fix

* add v1 masking (huggingface#429)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: leandro <leandro.vonwerra@spoud.io>

* Add masking (huggingface#461)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

* fix batched generation

* improve stopping criteria

* improve error handling in tool call

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>

* fix uknown tool

* fix rewards and increase bs

* remove unused script

* ugly WIP fix

* do not return modified obj for in-place operations

* do not return modified obj for in-place operations

* clean up stopping criterium

* push updates

* push update

* format, add docs

* rename file

* add kwargs to reward fn

* simplify example

* simplify example

* bug fix

* add a trivia example

* pre-commit

* max tool response length

* fix regex for multi-line

* refactor tool exceptions

* fix exceptions in tool

* add docs

* fix style

* make rich optional

* add docstrings

* add  tests

* add TextEnv tests (WIP)

* update triviaqa code

* update docs

* refactor text env

* update tests (WIP)

* add end2end test

* update docs

* upload tool demo

* refactor

* customizable system prompt

* add text env docs

* update index and toc

* fix `TextHistory` show methods

* add max length

* fix style

* fix typo

* refactor to kwargs in init and tasks to queries

* kwargs for reward docs

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/tool_demo.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/text_environments.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move to tool folder

* remove assets

* remove tool demo

* move rich import test to import utils

* add copyright

* fixes for masks in ppo trainer

* add text env api docs

* make precommit + add ppo test with mask

* move examples and add python

* fix style

* update triviaqa example

* add more docs

* update docs

* Update docs/source/learning_tools.mdx

* Apply suggestions from code review

* precommit

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: leandro von werra <leandro@hf.co>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants