forked from CarperAI/trlx

[Added T5 support to TRLX] A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF)

CG80499/trlx-with-T5

Transformer Reinforcement Learning X

trlX allows you to fine-tune 🤗 Hugging Face-supported language models (GPT-2, GPT-J, GPT-Neo, and GPT-NeoX based) with up to 20B parameters using reinforcement learning, driven by either a provided reward function or a reward-labeled dataset. Proximal Policy Optimization (PPO) and Implicit Language Q-Learning (ILQL) are implemented.
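For reference, PPO in its textbook form maximizes a clipped surrogate objective. Here $\hat{A}_t$ is an advantage estimate and $\epsilon$ is the clipping range. For a language model, $a_t$ is the token generated at step $t$ and $s_t$ is the prompt plus the tokens generated so far. This is only the standard form. trlX's actual training loss may add further terms (the combined PPO objective commonly includes a value-function loss, for example), and those details are not reproduced here.

```latex
$$
L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right)\right],
\qquad
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}
$$
```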

You can read more about trlX in our documentation.

Installation

git clone https://github.com/CarperAI/trlx.git
cd trlx
pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda
pip install -e .

How to Train

You can train a model using a reward function or a reward-labeled dataset.

Using a reward function

import trlx
model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples])
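The `reward_fn` passed above receives a batch of generated strings and returns one score per sample, in the same order. The following minimal sketch writes that reward as a named function that can be tested on its own. The name `cat_count_reward` is hypothetical, and casting the scores to `float` is our own assumption rather than a documented trlX requirement.

```python
from typing import List


def cat_count_reward(samples: List[str]) -> List[float]:
    """Score each generated sample by how often it contains 'cats'.

    Hypothetical helper mirroring the lambda above: one reward per sample,
    returned in the same order as the input batch.
    """
    # str.count counts non-overlapping occurrences of the substring
    return [float(sample.count("cats")) for sample in samples]


# Assumed drop-in replacement for the lambda:
# model = trlx.train('gpt2', reward_fn=cat_count_reward)
```

Using a named function keeps the scoring logic separate from the training call, so it can be checked on sample strings before a run.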

Using a reward-labeled dataset

# dataset = [samples, rewards], matched by position
model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)])
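In the example above, `dataset` is a two-element list. The first entry holds the samples and the second holds their rewards, matched by position. The following sketch shows a hypothetical helper, `to_reward_dataset`, that builds this layout from `(text, reward)` pairs. It assumes that this positional format is what `trlx.train` expects in this version.

```python
from typing import List, Sequence, Tuple


def to_reward_dataset(pairs: Sequence[Tuple[str, float]]) -> List[tuple]:
    """Convert (text, reward) pairs into the [samples, rewards] layout.

    Hypothetical helper: entry 0 holds all samples and entry 1 holds their
    rewards, aligned by index, as in the dataset example above.
    """
    if not pairs:
        # zip(*[]) yields nothing to unpack, so handle the empty case explicitly
        return [(), ()]
    samples, rewards = zip(*pairs)
    return [tuple(samples), tuple(float(r) for r in rewards)]


# Assumed usage, reproducing the dataset from the example above:
# dataset = to_reward_dataset([('dolphins', 1.0), ('geese', 100.0)])
# model = trlx.train('EleutherAI/gpt-j-6B', dataset=dataset)
```

Keeping each text next to its reward in the input pairs makes it harder to misalign the two sequences by accident.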

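The trained model is a wrapper over the given autoregressive model, so it can be prompted using the base model's 🤗 Transformers tokenizer:

tokenizer = AutoTokenizer.from_pretrained('gpt2')  # AutoTokenizer from transformers; matches the base model
model.generate(**tokenizer('Q: Who rules the world? A:', return_tensors='pt'), do_sample=True)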

Use 🤗 Accelerate to launch distributed training

accelerate config # choose DeepSpeed option
accelerate launch examples/simulacra.py

Use Ray Tune to launch a hyperparameter sweep

python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py

For more usage, see the examples.

Contributing

For development, check out these guidelines and read our docs.

Acknowledgements

Many thanks to Leandro von Werra for contributing trl, a library that initially inspired this repo.
