From dcd45d54520254c90c861ec84e21706d405dd884 Mon Sep 17 00:00:00 2001 From: Cong Date: Thu, 8 Jun 2023 22:43:37 +0800 Subject: [PATCH] feat: support adding tokens to the tokenizer. * Resize the model by default * Special tokens are not added, because the decode phase of PPO needs to skip certain special tokens, such as the EOS token, and would ignore them. Therefore only normal tokens are added. --- trlx/models/modeling_ppo.py | 13 +++++++++++++ trlx/trainer/__init__.py | 4 ++-- trlx/trainer/accelerate_base_trainer.py | 15 +++++++++------ trlx/trlx.py | 12 ++++++------ 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/trlx/models/modeling_ppo.py b/trlx/models/modeling_ppo.py index e3334f10b..31c6f73b1 100644 --- a/trlx/models/modeling_ppo.py +++ b/trlx/models/modeling_ppo.py @@ -411,6 +411,7 @@ def __init__( # The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model)) + self.embed_tokens = base_model.get_input_embeddings() self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:]) self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model)) self.lm_head = deepcopy(hf_get_lm_head(base_model)) @@ -425,6 +426,18 @@ def __init__( for parameter in self.parameters(): parameter.requires_grad_(False) + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_input_embeddings(self): + return self.embed_tokens + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + class GPTModelBranch(ModelBranch): def forward( # noqa: max-complexity diff --git a/trlx/trainer/__init__.py b/trlx/trainer/__init__.py index 5b8f67ab2..bedf522df 100644 --- a/trlx/trainer/__init__.py +++ b/trlx/trainer/__init__.py @@ -41,7 +41,7 @@ def __init__( logit_mask=None, stop_sequences=None, train_mode=False, - additional_special_tokens=None, + additional_tokens=None, ): self.store: 
BaseRolloutStore = None self.config = config @@ -50,7 +50,7 @@ def __init__( self.train_mode = train_mode self.logit_mask = logit_mask self.stop_sequences = stop_sequences - self.additional_special_tokens = additional_special_tokens + self.additional_tokens = additional_tokens def push_to_store(self, data): self.store.push(data) diff --git a/trlx/trainer/accelerate_base_trainer.py b/trlx/trainer/accelerate_base_trainer.py index ca99a2248..be7718768 100644 --- a/trlx/trainer/accelerate_base_trainer.py +++ b/trlx/trainer/accelerate_base_trainer.py @@ -70,6 +70,15 @@ def __init__(self, config, **kwargs): # noqa: C901 self.scheduler = self.setup_scheduler() self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) + self.tokenizer.add_tokens(self.additional_tokens) + # resize the model by-default + self.model.base_model.resize_token_embeddings(len(self.tokenizer)) + if hasattr(self.model, "frozen_head"): + self.model.frozen_head.resize_token_embeddings(len(self.tokenizer)) + else: + # resize a reference model when hydra heads are not used + self.ref_model.resize_token_embeddings(len(self.tokenizer)) + self.tokenizer.padding_side = config.tokenizer.padding_side self.tokenizer.truncation_side = config.tokenizer.truncation_side self.tokenizer.sep_token = "" @@ -77,12 +86,6 @@ def __init__(self, config, **kwargs): # noqa: C901 self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - if self.additional_special_tokens is not None and type(self.additional_special_tokens) is list: - self.tokenizer.add_special_tokens( - {"additional_special_tokens": self.additional_special_tokens} - ) - self.model.base_model.resize_token_embeddings(len(self.tokenizer)) - script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] if not isinstance(config.model.model_path, str): model_name = str(config.model.model_path).split()[0] diff --git a/trlx/trlx.py b/trlx/trlx.py index 20eb01654..2758f3b8a 100644 --- 
a/trlx/trlx.py +++ b/trlx/trlx.py @@ -1,6 +1,6 @@ import os import warnings -from typing import Callable, Dict, Iterable, List, Optional, Tuple +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from trlx.data.configs import TRLConfig from trlx.data.default_configs import ( @@ -23,7 +23,7 @@ def train( # noqa: C901 metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None, config: Optional[TRLConfig] = None, stop_sequences: Optional[List[str]] = [], - additional_special_tokens: Optional[List[str]] = None, + additional_tokens: Optional[Union[str, List[str]]] = None, ): """ Dispatches online, offline reinforcement training or supervised finetuning @@ -55,9 +55,9 @@ def train( # noqa: C901 stop_sequences (Optional[List[str]]): String sequences to trim generations (both for generating of experience and evaluation) up to its encounter in them. Generations will not contain them and also will also be right-stripped - additional_special_tokens (Optional[List[str]]): - A list of additional special tokens. Add them to the tokenizer to ensure they won’t be split by - the tokenization process. + additional_tokens (Optional[Union[str, List[str]]]): + A list of additional tokens. The given tokens are added only if they don’t already exist + in the vocabulary, each token then gets a new attributed id """ if config is None: warnings.warn( @@ -85,7 +85,7 @@ def train( # noqa: C901 reward_fn=reward_fn, metric_fn=metric_fn, stop_sequences=stop_sequences, - additional_special_tokens=additional_special_tokens, + additional_tokens=additional_tokens, **config.train.trainer_kwargs, )