-
Notifications
You must be signed in to change notification settings - Fork 473
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: support add tokens to tokenizer. #498
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Cong, that's a nice QoL improvement! However, there is one minor issue with it, but I hope you can resolve it
self.tokenizer.add_special_tokens( | ||
{"additional_special_tokens": self.additional_special_tokens} | ||
) | ||
self.model.base_model.resize_token_embeddings(len(self.tokenizer)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To improve compatibility with other modified tokenizers, I think it would be great if resizing happened by default, regardless of this if condition. Also, for PPO, the reference model/head should be resized likewise, otherwise, this error occurs:
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:144: operator(): block: [0,0,0], thread: [93,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
Traceback (most recent call last):
File "/trlx/examples/ppo_sentiments.py", line 58, in <module>
main(hparams)
File "/trlx/examples/ppo_sentiments.py", line 47, in main
trlx.train(
File "/trlx/trlx/trlx.py", line 133, in train
trainer.learn()
File "/trlx/trlx/trainer/accelerate_base_trainer.py", line 506, in learn
self.prepare_learning()
File "trlx/trlx/trainer/accelerate_ppo_trainer.py", line 239, in prepare_learning
self.make_experience(self.config.method.num_rollouts)
File "/trlx/trlx/trainer/accelerate_ppo_trainer.py", line 427, in make_experience
ref_logprobs = logprobs_of_labels(ref_logits[:, :-1, :], all_tokens[:, 1:])
File "/trlx/trlx/utils/modeling.py", line 224, in logprobs_of_labels
logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1))
RuntimeError: CUDA error: device-side assert triggered
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for your review, I will resolve it later.
My plan is:
- if hydra heads is used,
hasattr(self.model, "frozen_head")
, then I need to resize theself.model.frozen_head.decoder_blocks
, - if not, just resize the
self.ref_model
self.model.frozen_head.resize_token_embeddings(len(self.tokenizer)) | ||
else: | ||
# resize a reference model when hydra heads are not used | ||
self.ref_model.resize_token_embeddings(len(self.tokenizer)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when hydra heads are not used, ref_model
gets instantiated in AcceleratePPOTrainer, so maybe we can move this line there:
trlx/trlx/trainer/accelerate_ppo_trainer.py
Lines 71 to 74 in 404217b
if not hasattr(self.model, "frozen_head"): | |
self.ref_model = self.get_arch(self.config) | |
self.ref_model.to(self.accelerator.device) | |
self.ref_model.eval() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that's better.
dcd45d5
to
134bbf9
Compare
* Resize the model by-default * Adding special tokens is ignored by the decode phase of the PPO. This is because it needs to skip certain special tokens, such as EOS tokens. Therefore only add normal tokens.
move hydra heads and ref_model 's resize_token_embeddings function calls to AcceleratePPOTrainer
fd58c49
to
e7fc3e3
Compare
To improve the compatibility of various models initialized from different open-sourced models, people may want to add some tokens for better downstream tuning purposes.
For example, to improve our policy's adherence to our chat format, we may want to add ChatML tokens such as
"<|system|>", "<|assistant|>", "<|user|>", and "<|end|>"
to the policy tokenizer.Adding special tokens is ignored by the decode phase of the PPO. This is because it needs to skip certain special tokens, such as EOS tokens. Therefore, Will only add normal tokens.