generalized chat sft prompt #7655

Merged
merged 30 commits on Oct 10, 2023
Changes from 15 commits
Commits
30 commits
c66de49
fix dataset issues
yidong72 Oct 5, 2023
3e42b74
Merge branch 'main' into sft_mcore
yidong72 Oct 5, 2023
1f3d2d3
working version
yidong72 Oct 6, 2023
7fdc339
all passed
yidong72 Oct 6, 2023
87a01bb
refactor tests
yidong72 Oct 6, 2023
9c0ee5c
all pass
yidong72 Oct 6, 2023
ccaa6a0
working version
yidong72 Oct 6, 2023
14467c4
use end name signal for labels
yidong72 Oct 6, 2023
4a674e4
all fixed
yidong72 Oct 6, 2023
11bc6cd
update doc
yidong72 Oct 6, 2023
2e0285a
style fix
yidong72 Oct 6, 2023
4ba2395
remove unused imports
yidong72 Oct 6, 2023
d1b8328
make sure nccl not timing out
yidong72 Oct 6, 2023
5bf546e
style fix
yidong72 Oct 6, 2023
f945ec6
Merge branch 'main' into sft_mcore
yidong72 Oct 6, 2023
cd7c77a
generate example template
yidong72 Oct 6, 2023
d734830
generic end of name token
yidong72 Oct 6, 2023
33b7910
style fix
yidong72 Oct 6, 2023
e293336
Merge branch 'sft_mcore' of github.com:NVIDIA/NeMo into sft_mcore
yidong72 Oct 6, 2023
c99b55f
add the chat prompt format into the config
yidong72 Oct 6, 2023
b64f0bd
make sure sft working
yidong72 Oct 6, 2023
86bb7b0
address reviewer comment
yidong72 Oct 6, 2023
019afa4
Merge branch 'main' into sft_mcore
yidong72 Oct 6, 2023
3ddd9cd
fix non
yidong72 Oct 7, 2023
a1789e4
try openAI prompt
yidong72 Oct 7, 2023
4db2188
Merge branch 'main' into sft_mcore
yidong72 Oct 7, 2023
d36d3a9
remove unused imports
yidong72 Oct 7, 2023
162be79
remove human labels from the data
yidong72 Oct 9, 2023
700d9f2
use hf dataset to clean
yidong72 Oct 9, 2023
ed68643
reviewer comments
yidong72 Oct 10, 2023
7 changes: 6 additions & 1 deletion examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -13,6 +13,7 @@
# limitations under the License.

import asyncio
import datetime
import os
import threading
from functools import partial
@@ -167,7 +168,11 @@ def remove_padded_prompts(response, nb_paddings):
def main(cfg) -> None:

# trainer required for restoring model parallel models
trainer = Trainer(strategy=NLPDDPStrategy(), **cfg.trainer, callbacks=[CustomProgressBar()])
trainer = Trainer(
strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)),
**cfg.trainer,
callbacks=[CustomProgressBar()],
)

if cfg.gpt_model_file is not None:
if (
@@ -71,6 +71,11 @@ model:

data:
chat: False # whether use chatbot data or not
chat_prompt_tokens: # special tokens for the chat prompts, a dictionary of {token_type: token}. Note that some tokenizers may merge the characters at the junction between {end_of_turn}{turn_start}, e.g. in '<im end><im start>' the '><' is sometimes merged into a single token. This is not supported; avoid such token combinations.
system_turn_start: '<extra_id_0>'
turn_start: '<extra_id_1>'
label_start: '<extra_id_2>'
end_of_turn: '\n'
train_ds:
# Example of how to specify paths to multiple datasets
# file_names:
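For illustration only (not part of the diff): a minimal sketch of how the special tokens above compose a chat prompt, assuming the default values and the header/turn layout implemented in GPTSFTChatDataset below. The render helper is hypothetical.

    # Hypothetical sketch; token values come from the config defaults above.
    special_tokens = {
        'system_turn_start': '<extra_id_0>',
        'turn_start': '<extra_id_1>',
        'label_start': '<extra_id_2>',
        'end_of_turn': '\n',
    }

    def render(system, turns, tokens=special_tokens):
        # header: system_turn_start + 'System' + newline + system message + end_of_turn
        text = f"{tokens['system_turn_start']}System\n{system}{tokens['end_of_turn']}"
        for speaker, value in turns:
            # each turn: turn_start + speaker name + newline + turn text + end_of_turn
            text += f"{tokens['turn_start']}{speaker}\n{value}{tokens['end_of_turn']}"
        return text

    print(render("You are a helpful assistant.", [("User", "Hi!"), ("Assistant", "Hello!")]))
    # <extra_id_0>System
    # You are a helpful assistant.
    # <extra_id_1>User
    # Hi!
    # <extra_id_1>Assistant
    # Hello!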
1 change: 0 additions & 1 deletion examples/nlp/language_modeling/tuning/megatron_gpt_sft.py
@@ -18,7 +18,6 @@
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel
from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel
@@ -16,26 +16,41 @@

import torch

from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
from nemo.utils import logging

__all__ = ['GPTSFTChatDataset']

IGNORE_INDEX = -100
END_SIGNAL = "\n"
END_NAME_SIGNAL = "\n"

SYSTEM_TOKEN = "<extra_id_0>System\n"
TURN_TOKEN = "<extra_id_1>"
PREFIX_STR = (
"\x00" # the prefix string used in the tokenizer to deal with the added empty token for some of the tokenizers
)

IGNORE_INDEX = -100
END_NAME_SIGNAL = "\n" # token to indicate the end of the name. The name can be system, user, assistant, etc.
SYSTEM_TOKEN = "System" + END_NAME_SIGNAL

TYPE_INSTRUCTION = {
'TEXT_TO_VALUE': "",
'VALUE_TO_TEXT': '',
}


def find_small_tensor(small, large):
""" find the location of the small tensor in the large tensor.
e.g. small = [1,3], large = [2,3,1,3], returns 2
small = [3,2], large = [2,3,1,3], returns -1
Args:
small (tensor): small tensor
large (tensor): large tensor
"""
for i in range(large.size(0) - small.size(0) + 1):
if torch.equal(large[i : i + small.size(0)], small):
return i
return -1
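For illustration only (not part of the diff), a quick usage sketch of find_small_tensor:

    import torch
    small = torch.tensor([1, 3])
    large = torch.tensor([2, 3, 1, 3])
    assert find_small_tensor(small, large) == 2                  # match starts at index 2
    assert find_small_tensor(torch.tensor([3, 2]), large) == -1  # no contiguous match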


def _mask_targets(
target,
tokenized_lens,
@@ -45,8 +60,10 @@
tokenizer,
mask_role,
gtype,
extra_id_2_token_id,
new_line_token_id,
special_tokens,
label_start_ids,
num_turn_start_tokens,
):
""" This function masks the tokens so the loss is computed only on the non-masked role's responses.
For 'TEXT_TO_VALUE' type, the loss is computed on the value attributes.
@@ -60,50 +77,66 @@
tokenizer (TokenizerSpec): tokenizer object
mask_role (str): the speaker id to be masked from loss computation
gtype (str): either 'TEXT_TO_VALUE' or 'VALUE_TO_TEXT'
extra_id_2_token_id (int): <extra_id_2> token id
new_line_token_id (int): new line token id

special_tokens (dict): special tokens used for the chat prompt. It has the keys: system_turn_start, turn_start, label_start, end_of_turn
label_start_ids (list): list of label start token ids
num_turn_start_tokens (int): number of tokens in the turn_start string
"""
TURN_TOKEN = special_tokens['turn_start']
label_start_ids = torch.tensor(label_start_ids)

cur_idx = header_len
tgt_len = target.shape[0]
for i, (tokenized_len, speaker, s_id) in enumerate(zip(tokenized_lens, speakers, s_ids)):
# note: SentencePiece will add an extra empty token in front; has to compute the diff
id1 = tokenizer.text_to_ids("<extra_id_1>")
id2 = tokenizer.text_to_ids("<extra_id_1>" + TURN_TOKEN + speaker + END_NAME_SIGNAL)
skip_name_len = len(id2) - len(id1)
if extra_id_2_token_id is None:
raise ValueError("extra_id_2 is not in the vocabulary")
if (s_id == extra_id_2_token_id).any().item():
id1 = tokenizer.text_to_ids(PREFIX_STR)
id2 = tokenizer.text_to_ids(PREFIX_STR + TURN_TOKEN + speaker + END_NAME_SIGNAL)
skip_name_len = len(id2) - len(
id1
) # s_ids[:skip_name_len] is the name part of the prompt 'TURN_TOKEN + speaker + END_NAME_SIGNAL'
# get the position of the label start string in this turn
location = find_small_tensor(label_start_ids, s_id)

if location >= 0:
# if it contains the label start tokens
if gtype == 'VALUE_TO_TEXT':
# if contains the token <extra_id_2>
assert skip_name_len == torch.where((s_id == extra_id_2_token_id))[0].item()
# find new line token id 14
more_skip_len = torch.where((s_id[skip_name_len:] == new_line_token_id))[0][0].item() + 1
# handles the case of conditioning on labels to generate the response
# the next token after the name part of the prompt is the beginning of the label start tokens
assert skip_name_len == location
# find the first new line token after the label part, which indicates the end of the whole label string
newline_loc = torch.where((s_id[skip_name_len:] == new_line_token_id))[0]
if len(newline_loc) == 0:
# cannot find a new line token, which means the whole turn is just a partial label string. Mask the whole turn
target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
continue
# skip the label part and the new line token
more_skip_len = newline_loc[0].item() + 1
# skip the name part and the label part
skip_name_len += more_skip_len
elif gtype == 'TEXT_TO_VALUE':
skip_name_len = torch.where((s_id == extra_id_2_token_id))[0].item() + 1
# handles the case of conditioning on the response to generate labels
# skip the name part, the response, and the label start tokens; the remainder is the label string without the label start, e.g. 'quality:9,toxicity:8...'
skip_name_len = location + len(label_start_ids)
if cur_idx >= tgt_len:
break
elif cur_idx + tokenized_len < tgt_len:
# Check whether the mask is applied to the correct position, the first token is turn token: <extra_id_1>
# s_id[2:] skips the artifact empty token and the turn token
# target[cur_idx + 1:cur_idx + tokenized_len] skip the turn token
# Check whether the mask is applied to the correct position; the first tokens are the turn start tokens
if not torch.equal(target[cur_idx + 1 : cur_idx + tokenized_len], s_id[1:]):
logging.warning("a sentence mismatches the corresponding piece " "in the conversation")
if i == 0 and (gtype == 'VALUE_TO_TEXT' or gtype is None):
# mask the first turn completely to provide at least one turn as context
# mask the first turn completely to provide at least one turn as context for the rest
target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
elif speaker == mask_role and i == 1 and gtype == 'TEXT_TO_VALUE':
# leave the first human tag unmasked
target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
# leave the first turn start tag unmasked, which serves as the end of turn signal
target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX
elif speaker == mask_role and (i > 1):
# leave the first human tag unmasked
target[cur_idx + 1 : cur_idx + tokenized_len] = IGNORE_INDEX
# leave the first turn start tag unmasked, which serves as the end of turn signal
target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX
elif speaker == mask_role and (i <= 1):
# mask out everything in the second turn
target[cur_idx : cur_idx + tokenized_len] = IGNORE_INDEX
else:
# mask up to the name end, need to remove one as skip name has an extra artifact empty token
# mask up to the name part plus the label part for VALUE_TO_TEXT; the name part, response and label start tokens for TEXT_TO_VALUE; or just the name part if gtype is None
target[cur_idx : cur_idx + skip_name_len] = IGNORE_INDEX
cur_idx += tokenized_len
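For illustration only (not part of the diff), a toy sketch of the masking convention above: a masked-role turn keeps its turn start token(s) unmasked and sets the rest of the turn to IGNORE_INDEX.

    import torch
    IGNORE_INDEX = -100
    target = torch.arange(12)      # stand-in token ids for the whole conversation
    cur_idx, tokenized_len = 4, 6  # pretend a masked-role turn spans positions 4..9
    num_turn_start_tokens = 1      # e.g. '<extra_id_1>' tokenizes to a single token
    # keep the turn start token as the end-of-turn signal, mask the rest of the turn
    target[cur_idx + num_turn_start_tokens : cur_idx + tokenized_len] = IGNORE_INDEX
    print(target)  # tensor([0, 1, 2, 3, 4, -100, -100, -100, -100, -100, 10, 11])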

@@ -112,16 +145,20 @@ def cannonical_form_formater(cannoical_form):
return f'<extra_id_2>{cannoical_form}\n'


def response_value_formater(label):
def response_value_formater(label, label_start, end_signal):
if isinstance(label, str):
return '<extra_id_2>' + label + '\n'
return label_start + label + end_signal
elif label is None:
return ''
else:
raise ValueError(f'Unknown label type {type(label)}, only str type is supported')
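For illustration only (not part of the diff), with the default special tokens:

    assert response_value_formater('quality:9', '<extra_id_2>', '\n') == '<extra_id_2>quality:9\n'
    assert response_value_formater(None, '<extra_id_2>', '\n') == ''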


def _add_speaker_and_signal(header, source, mask_role, gtype):
def _add_speaker_and_signal(header, source, mask_role, gtype, special_tokens):
TURN_TOKEN = special_tokens['turn_start']
END_SIGNAL = special_tokens['end_of_turn']
LABEL_START = special_tokens['label_start']

"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = ""
conversation = header
@@ -138,7 +175,11 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
+ role_token
+ sentence_from
+ END_NAME_SIGNAL
+ (response_value_formater(sentence['label']) if 'label' in sentence else '')
+ (
response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL)
if 'label' in sentence
else ''
)
+ sentence["value"]
+ END_SIGNAL
)
@@ -150,7 +191,11 @@ def _add_speaker_and_signal(header, source, mask_role, gtype):
+ END_NAME_SIGNAL
+ sentence["value"]
+ END_SIGNAL
+ (response_value_formater(sentence['label']) if 'label' in sentence else '')
+ (
response_value_formater(sentence['label'], LABEL_START, END_NAME_SIGNAL)
if 'label' in sentence
else ''
)
)
else:
raise ValueError(
@@ -163,44 +208,51 @@
return conversation
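For illustration only (not part of the diff): the two turn layouts produced above, assuming the default tokens and a label of 'quality:9'. For VALUE_TO_TEXT the label precedes the response; for TEXT_TO_VALUE it follows it.

    # VALUE_TO_TEXT turn (condition on the label to generate the response):
    #   '<extra_id_1>Assistant\n<extra_id_2>quality:9\nHello!\n'
    # TEXT_TO_VALUE turn (condition on the response to generate the label):
    #   '<extra_id_1>Assistant\nHello!\n<extra_id_2>quality:9\n'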


def preprocess(source: dict, tokenizer: TokenizerSpec, extra_id_2_token_id: int, new_line_token_id: int):
def preprocess(
source: dict,
tokenizer: TokenizerSpec,
new_line_token_id: int,
label_start_ids: list,
special_tokens: dict,
num_turn_start_tokens: int,
):
"""
Given a conversation list, this transform:
1. Add signal '### ' at the beginning of each sentence, with end signal '\n';
2. Concatenate conversations together;
3. Tokenize the concatenated conversation;
4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
"""
END_SIGNAL = special_tokens['end_of_turn']
data_type = None
if 'type' in source:
data_type = source['type']
assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported"
if data_type is not None:
assert data_type in TYPE_INSTRUCTION, f"source type {data_type} not supported"
# add end signal and concatenate together
conversation = source['system']
if data_type is not None:
if TYPE_INSTRUCTION[data_type] != '':
conversation = conversation + '\n' + TYPE_INSTRUCTION[data_type]
mask_role = source.get('mask', 'User')
header = f"{SYSTEM_TOKEN}{conversation}"
conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type)
header = f"{special_tokens['system_turn_start']}{SYSTEM_TOKEN}{conversation}{END_SIGNAL}"
conversation = _add_speaker_and_signal(header, source['conversations'], mask_role, data_type, special_tokens)
# tokenize conversations
input_ids = tokenizer.text_to_ids(conversation)
target = copy.deepcopy(input_ids)
header_len = len(tokenizer.text_to_ids(header))
header_tokens = tokenizer.text_to_ids(header)
header_len = len(header_tokens)

ids = []
tokenized_lens = []
assert torch.equal(torch.tensor(target[:header_len]), torch.tensor(header_tokens))
for s in source['conversations']:
if isinstance(tokenizer, SentencePieceTokenizer):
tokenized_sentence = tokenizer.text_to_ids(s["value"])
ids.append(torch.tensor(tokenized_sentence)[1:])
# remove one token as it adds an empty token in front
tokenized_lens.append(len(tokenized_sentence) - 1)
else:
tokenized_sentence = tokenizer.text_to_ids(s["value"])
ids.append(torch.tensor(tokenized_sentence))
# remove one token as it adds an empty token in front
tokenized_lens.append(len(tokenized_sentence))
# hack to remove the extra empty token in front
id1 = tokenizer.text_to_ids(PREFIX_STR + s["value"])
id2 = tokenizer.text_to_ids(PREFIX_STR)
tokenized_sentence = id1[len(id2) :]
ids.append(torch.tensor(tokenized_sentence))
tokenized_lens.append(len(tokenized_sentence))
speakers = [sentence["from"] for sentence in source['conversations']]
assert mask_role in speakers, "mask role not in the conversation"
target = torch.LongTensor(target)
@@ -216,8 +268,10 @@
tokenizer,
mask_role,
data_type,
extra_id_2_token_id,
new_line_token_id,
special_tokens,
label_start_ids,
num_turn_start_tokens,
)
mask = (target != IGNORE_INDEX).bool()
assert mask.sum().item() != 0, "mask is empty"
@@ -228,45 +282,39 @@
return dict(input_ids=input_ids, mask=mask, context_ids=context_ids, answer_ids=answer_ids)
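For illustration only (not part of the diff): a sketch of the example format preprocess consumes, inferred from the field accesses above (the exact schema and values are assumptions).

    example_source = {
        'system': 'You are a helpful assistant.',
        'type': 'VALUE_TO_TEXT',   # optional; must be a key of TYPE_INSTRUCTION when present
        'mask': 'User',            # optional; role excluded from the loss (defaults to 'User')
        'conversations': [
            {'from': 'User', 'value': 'Hi!'},
            {'from': 'Assistant', 'value': 'Hello!', 'label': 'quality:9'},  # 'label' is optional
        ],
    }
    # preprocess(example_source, tokenizer, new_line_token_id, label_start_ids,
    #            special_tokens, num_turn_start_tokens)
    # -> dict(input_ids=..., mask=..., context_ids=..., answer_ids=...)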


def _check_token_in_vocab(tokenizer, token):
ids = tokenizer.text_to_ids(token)
if isinstance(tokenizer, SentencePieceTokenizer):
return len(ids) == 2
else:
return len(ids) == 1


class GPTSFTChatDataset(GPTSFTDataset):
def _maybe_validate_prompt_template(self):
pass

def _build_samples_mapping(self):
super()._build_samples_mapping()
assert hasattr(self.tokenizer, "vocab"), "tokenizer should have vocab property, not supported"
assert _check_token_in_vocab(
self.tokenizer, '<extra_id_0>'
), "<extra_id_0> not in the tokenizer vocab. not supported"
assert _check_token_in_vocab(
self.tokenizer, '<extra_id_1>'
), "<extra_id_1> not in the tokenizer vocab. not supported"
# calculate <extra_id_2> id value
if _check_token_in_vocab(self.tokenizer, '<extra_id_2>'):
ids_1 = self.tokenizer.text_to_ids('<extra_id_1><extra_id_2>')
ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
self.extra_id_2_token_id = ids_1[len(ids_2) :][0]
else:
self.extra_id_2_token_id = None
ids_1 = self.tokenizer.text_to_ids('<extra_id_1>\n')
ids_2 = self.tokenizer.text_to_ids('<extra_id_1>')
LABEL_START = self.special_tokens['label_start']
id1 = self.tokenizer.text_to_ids(PREFIX_STR)
id2 = self.tokenizer.text_to_ids(PREFIX_STR + LABEL_START)
self.label_start_tokens = id2[len(id1) :]
ids_1 = self.tokenizer.text_to_ids(PREFIX_STR + '\n')
ids_2 = self.tokenizer.text_to_ids(PREFIX_STR)
self.new_line_token_id = ids_1[len(ids_2) :][0]

ids_1 = self.tokenizer.text_to_ids(PREFIX_STR + self.special_tokens['turn_start'])
ids_2 = self.tokenizer.text_to_ids(PREFIX_STR)
self.num_turn_start_tokens = len(ids_1) - len(ids_2)
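For illustration only (not part of the diff): the PREFIX_STR pattern used above, as a standalone sketch. Tokenizing PREFIX_STR + s and stripping the ids of PREFIX_STR alone yields the ids of s without the empty-token artifact some tokenizers prepend (the helper name is hypothetical).

    def ids_without_prefix_artifact(tokenizer, s, prefix='\x00'):
        # tokenize prefix + s, then drop the leading ids that belong to the prefix alone
        with_prefix = tokenizer.text_to_ids(prefix + s)
        prefix_only = tokenizer.text_to_ids(prefix)
        return with_prefix[len(prefix_only):]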

def _process_example(self, example):
"""
Create an example by concatenating text and answer.
Truncation is carried out when needed, but it is performed only on the prompt side.
BOS, EOS, and SEP are added if specified.
"""
result = preprocess(example, self.tokenizer, self.extra_id_2_token_id, self.new_line_token_id)
result = preprocess(
example,
self.tokenizer,
self.new_line_token_id,
self.label_start_tokens,
self.special_tokens,
self.num_turn_start_tokens,
)

# store metadata in the dataset, in case the user has keys required in the prediction json files
metadata = {k: v for k, v in example.items() if k not in ['conversations']}