diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py
index 51d54cf36..067f58b72 100644
--- a/trlx/models/modeling_ppo.py
+++ b/trlx/models/modeling_ppo.py
@@ -543,6 +543,18 @@ def __init__(
         for parameter in self.parameters():
             parameter.requires_grad_(False)
 
+    def set_input_embeddings(self, value):
+        self.embed_tokens = value
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
 
 class GPTModelBranch(ModelBranch):
     def forward(  # noqa: max-complexity
diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py
index 8e0d239df..bedf522df 100644
--- a/trlx/trainer/__init__.py
+++ b/trlx/trainer/__init__.py
@@ -41,6 +41,7 @@ def __init__(
         logit_mask=None,
         stop_sequences=None,
         train_mode=False,
+        additional_tokens=None,
     ):
         self.store: BaseRolloutStore = None
         self.config = config
@@ -49,6 +50,7 @@ def __init__(
         self.train_mode = train_mode
         self.logit_mask = logit_mask
         self.stop_sequences = stop_sequences
+        self.additional_tokens = additional_tokens
 
     def push_to_store(self, data):
         self.store.push(data)
diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py
index 62c09fd0c..84aff4c93 100644
--- a/trlx/trainer/accelerate_base_trainer.py
+++ b/trlx/trainer/accelerate_base_trainer.py
@@ -68,6 +68,10 @@ def __init__(self, config, **kwargs):  # noqa: C901
         self.scheduler = self.setup_scheduler()
 
         self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
+        self.tokenizer.add_tokens(self.additional_tokens)
+        # resize the model embeddings by default to match the new vocabulary size
+        self.model.base_model.resize_token_embeddings(len(self.tokenizer))
+
         self.tokenizer.padding_side = config.tokenizer.padding_side
         self.tokenizer.truncation_side = config.tokenizer.truncation_side
         self.tokenizer.sep_token = ""
diff --git a/trlx/trainer/accelerate_ppo_trainer.py b/trlx/trainer/accelerate_ppo_trainer.py
index a7fcbb447..97cd274e7 100644
--- a/trlx/trainer/accelerate_ppo_trainer.py
+++ b/trlx/trainer/accelerate_ppo_trainer.py
@@ -71,9 +71,15 @@ def __init__(self, config: TRLConfig, **kwargs):
 
         # Set up a reference model when hydra heads are not used
         if not hasattr(self.model, "frozen_head") and not self.model.peft_type:
+            # Full reference copy
             self.ref_model = self.get_arch(self.config)
+            self.ref_model.base_model.resize_token_embeddings(len(self.tokenizer))
             self.ref_model.to(self.accelerator.device)
             self.ref_model.eval()
+        elif hasattr(self.model, "frozen_head") and self.model.frozen_head is not None:
+            # Hydra reference: use the frozen base layers and head as the reference model; resize the frozen head
+            self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
+        # else: PEFT reference
 
         # Set up the KL controller
         # This helps prevent large divergences in the controller (policy)
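For context, the changes above follow the standard Hugging Face pattern for growing a vocabulary: `tokenizer.add_tokens(...)` registers the new tokens, and every module that owns an embedding matrix (the policy's base model, a full reference copy, or a hydra `frozen_head`) must then be resized, which is why `ModelBranch` now exposes the input/output embedding accessors that `resize_token_embeddings` relies on. A minimal sketch of the pattern on a plain causal LM (model name and token strings are illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# add_tokens skips tokens already in the vocabulary and returns
# the number of tokens that were actually added
num_added = tokenizer.add_tokens(["<mask_1>", "<critique>"])

# resize_token_embeddings uses get/set_input_embeddings (and the output
# head, when present) to grow the embedding matrices to the new size
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))
```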
""" Dispatches online, offline reinforcement training or supervised finetuning @@ -54,6 +55,9 @@ def train( # noqa: C901 stop_sequences (Optional[List[str]]): String sequences to trim generations (both for generating of experience and evaluation) up to its encounter in them. Generations will not contain them and also will also be right-stripped + additional_tokens (Optional[Union[str, List[str]]]): + A list of additional tokens. The given tokens are added only if they don’t already exist + in the vocabulary, each token then gets a new attributed id """ if config is None: warnings.warn( @@ -81,6 +85,7 @@ def train( # noqa: C901 reward_fn=reward_fn, metric_fn=metric_fn, stop_sequences=stop_sequences, + additional_tokens=additional_tokens, **config.train.trainer_kwargs, )