feat: support add tokens to tokenizer.
* Resize the model's token embeddings by default
* Add only normal tokens, not special tokens: the decode phase of PPO skips special tokens such as EOS, so tokens added as special would be dropped from decoded generations (see the sketch below the change summary)
congchan committed Jun 8, 2023
1 parent 74ea532 commit dcd45d5
Showing 4 changed files with 30 additions and 14 deletions.
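The change hinges on a distinction in the Hugging Face tokenizer API: tokens registered with `add_tokens` are ordinary vocabulary entries, whereas tokens registered with `add_special_tokens` are stripped whenever text is decoded with `skip_special_tokens=True`, which the PPO decode phase relies on to drop tokens such as EOS. Below is a minimal illustrative sketch, not part of the commit; the `gpt2` checkpoint and the `<sep>` token are placeholder choices.

```python
# Illustrative sketch only (not from the commit): why normal tokens survive
# PPO-style decoding while special tokens would not. "gpt2" and "<sep>" are
# placeholder choices.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Add "<sep>" as a *normal* token (what this commit does) ...
tokenizer.add_tokens(["<sep>"])
# ... rather than as a special token, which decode(skip_special_tokens=True)
# would strip along with EOS:
# tokenizer.add_special_tokens({"additional_special_tokens": ["<sep>"]})

# The embedding matrix must grow to cover the new vocabulary entry.
model.resize_token_embeddings(len(tokenizer))

ids = tokenizer("hello <sep> world", add_special_tokens=False).input_ids
print(tokenizer.decode(ids, skip_special_tokens=True))  # "<sep>" is preserved
```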
13 changes: 13 additions & 0 deletions trlx/models/modeling_ppo.py
@@ -411,6 +411,7 @@ def __init__(

# The branch is defined by the last `num_layers_unfrozen` layers of the pretrained model
decoder_blocks = deepcopy(hf_get_decoder_blocks(base_model))
self.embed_tokens = base_model.get_input_embeddings()
self.decoder_blocks = nn.ModuleList(list(decoder_blocks)[-num_layers_unfrozen:])
self.final_norm = deepcopy(hf_get_decoder_final_norm(base_model))
self.lm_head = deepcopy(hf_get_lm_head(base_model))
@@ -425,6 +426,18 @@ def __init__(
for parameter in self.parameters():
parameter.requires_grad_(False)

def set_input_embeddings(self, value):
self.embed_tokens = value

def get_input_embeddings(self):
return self.embed_tokens

def get_output_embeddings(self):
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings


class GPTModelBranch(ModelBranch):
def forward( # noqa: max-complexity
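The new accessors matter because `transformers.PreTrainedModel.resize_token_embeddings` works through exactly these hooks, which is what allows the trainer change further down to resize the frozen hydra branch; note also that `embed_tokens` is taken by reference from the base model via `get_input_embeddings()` rather than deep-copied like the decoder blocks. A rough standalone sketch of the mechanism, where `TinyBranch` is a made-up stand-in for `ModelBranch`:

```python
# Rough sketch of why the accessors enable resizing; TinyBranch is a made-up
# stand-in for ModelBranch, not code from this commit.
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class TinyBranch(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Without these hooks, PreTrainedModel.resize_token_embeddings has no way
    # to find or replace the branch's embedding matrices.
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings


config = PretrainedConfig(vocab_size=100, hidden_size=8, tie_word_embeddings=False)
branch = TinyBranch(config)
branch.resize_token_embeddings(110)                   # grows both matrices
print(branch.get_input_embeddings().num_embeddings)   # 110
print(branch.get_output_embeddings().out_features)    # 110
```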
4 changes: 2 additions & 2 deletions trlx/trainer/__init__.py
@@ -41,7 +41,7 @@ def __init__(
logit_mask=None,
stop_sequences=None,
train_mode=False,
additional_special_tokens=None,
additional_tokens=None,
):
self.store: BaseRolloutStore = None
self.config = config
@@ -50,7 +50,7 @@ def __init__(
self.train_mode = train_mode
self.logit_mask = logit_mask
self.stop_sequences = stop_sequences
self.additional_special_tokens = additional_special_tokens
self.additional_tokens = additional_tokens

def push_to_store(self, data):
self.store.push(data)
15 changes: 9 additions & 6 deletions trlx/trainer/accelerate_base_trainer.py
@@ -70,19 +70,22 @@ def __init__(self, config, **kwargs): # noqa: C901
self.scheduler = self.setup_scheduler()

self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path)
self.tokenizer.add_tokens(self.additional_tokens)
# resize the model by-default
self.model.base_model.resize_token_embeddings(len(self.tokenizer))
if hasattr(self.model, "frozen_head"):
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer))
else:
# resize a reference model when hydra heads are not used
self.ref_model.resize_token_embeddings(len(self.tokenizer))

self.tokenizer.padding_side = config.tokenizer.padding_side
self.tokenizer.truncation_side = config.tokenizer.truncation_side
self.tokenizer.sep_token = "<sep>"
if config.model.model_arch_type != "seq2seq":
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

if self.additional_special_tokens is not None and type(self.additional_special_tokens) is list:
self.tokenizer.add_special_tokens(
{"additional_special_tokens": self.additional_special_tokens}
)
self.model.base_model.resize_token_embeddings(len(self.tokenizer))

script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0]
if not isinstance(config.model.model_path, str):
model_name = str(config.model.model_path).split()[0]
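The hunk above moves token addition ahead of the tokenizer configuration and resizes every copy of the embeddings that exists: the trainable base model always, plus either the frozen hydra branch (when hydra heads are used) or the standalone reference model; without this, a newly added token id would index past the end of the frozen copy's embedding matrix. A condensed sketch of that control flow, where `trainer` is a stand-in object with the same attributes rather than the actual `AccelerateRLTrainer`:

```python
# Condensed sketch of the resize flow above; `trainer` is a stand-in object
# with the same attributes, not the actual AccelerateRLTrainer.
def resize_all_embeddings(trainer, additional_tokens):
    # Register the new normal (non-special) tokens first so that
    # len(tokenizer) reflects the grown vocabulary.
    trainer.tokenizer.add_tokens(additional_tokens)
    new_size = len(trainer.tokenizer)

    # The trainable policy model always gets resized.
    trainer.model.base_model.resize_token_embeddings(new_size)

    if hasattr(trainer.model, "frozen_head"):
        # Hydra architecture: the frozen branch keeps its own lm_head (and a
        # reference to the embeddings), so it must be resized as well.
        trainer.model.frozen_head.resize_token_embeddings(new_size)
    else:
        # No hydra heads: a separate reference model is used instead.
        trainer.ref_model.resize_token_embeddings(new_size)
```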
12 changes: 6 additions & 6 deletions trlx/trlx.py
@@ -1,6 +1,6 @@
import os
import warnings
from typing import Callable, Dict, Iterable, List, Optional, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

from trlx.data.configs import TRLConfig
from trlx.data.default_configs import (
@@ -23,7 +23,7 @@ def train( # noqa: C901
metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]] = None,
config: Optional[TRLConfig] = None,
stop_sequences: Optional[List[str]] = [],
additional_special_tokens: Optional[List[str]] = None,
additional_tokens: Optional[Union[str, List[str]]] = None,
):
"""
Dispatches online, offline reinforcement training or supervised finetuning
@@ -55,9 +55,9 @@ def train( # noqa: C901
stop_sequences (Optional[List[str]]):
String sequences at which generations are truncated (both when generating experience and during
evaluation). Generations will not contain them and will also be right-stripped
additional_special_tokens (Optional[List[str]]):
A list of additional special tokens. Add them to the tokenizer to ensure they won’t be split by
the tokenization process.
additional_tokens (Optional[Union[str, List[str]]]):
A token or list of tokens to add to the tokenizer as normal (non-special) tokens. Each token is
added only if it does not already exist in the vocabulary, and is then assigned a new id
"""
if config is None:
warnings.warn(
@@ -85,7 +85,7 @@ def train( # noqa: C901
reward_fn=reward_fn,
metric_fn=metric_fn,
stop_sequences=stop_sequences,
additional_special_tokens=additional_special_tokens,
additional_tokens=additional_tokens,
**config.train.trainer_kwargs,
)

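From the caller's side, the change amounts to one new keyword on `trlx.train`. A hedged usage sketch; the reward function, prompts, and token strings below are placeholders rather than anything taken from the commit:

```python
# Hypothetical usage of the new keyword; reward_fn, prompts, and the token
# strings are placeholders, not taken from the commit.
from typing import List

import trlx


def reward_fn(samples: List[str], **kwargs) -> List[float]:
    # Toy reward: prefer longer completions.
    return [float(len(s)) for s in samples]


trainer = trlx.train(
    "gpt2",
    reward_fn=reward_fn,
    prompts=["Tell me a joke:", "Summarize this text:"],
    # New in this commit: plain (non-special) tokens added to the tokenizer;
    # model embeddings are resized to match automatically.
    additional_tokens=["<sep>", "<user>", "<assistant>"],
)
```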
