diff --git a/docs/source/customization.mdx b/docs/source/customization.mdx
index b75d098709..69a924ec9b 100644
--- a/docs/source/customization.mdx
+++ b/docs/source/customization.mdx
@@ -38,8 +38,8 @@ from transformers import GPT2Tokenizer
 from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
 
 # 1. load a pretrained model
-model = AutoModelWithValueModel.from_pretrained('gpt2')
-model_ref = AutoModelWithValueModel.from_pretrained('gpt2')
+model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
+model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
 tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 
 # 2. define config
@@ -119,4 +119,4 @@ tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
 ppo_config = {'batch_size': 1, 'forward_batch_size': 1}
 config = PPOConfig(**ppo_config)
 ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
-```
\ No newline at end of file
+```
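
For reference, here is how the corrected setup assembles once both hunks are applied. This is a minimal sketch, not the exact docs example: the first hunk cuts off at `# 2. define config`, so the `PPOConfig` keyword arguments below are borrowed from the second hunk's bloom-560m example, and the four-argument `PPOTrainer(config, model, model_ref, tokenizer)` call is the TRL API this diff targets.

```python
# Assembled from the two hunks above; the config kwargs come from the
# second hunk and may differ from the actual gpt2 example in the docs.
from transformers import GPT2Tokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

# 1. load a pretrained model with a value head (the renamed class),
#    plus a second copy to serve as the frozen reference model
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# 2. define config and build the trainer
config = PPOConfig(batch_size=1, forward_batch_size=1)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)
```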