Skip to content

Commit

Permalink
add the chat prompt format into the config
Browse files Browse the repository at this point in the history
Signed-off-by: Yi Dong <yidong@nvidia.com>
  • Loading branch information
yidong72 committed Oct 6, 2023
1 parent e293336 commit c99b55f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,9 @@ def get_prompt_template_example(self):
'system': '{system message}',
'conversations': [
{'from': 'User', 'value': '{turn 1 user message}', 'label': None},
{'from': 'Assitant', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'},
{'from': 'Assistant', 'value': '{turn 1 assistant message}', 'label': '{turn 1 assistant label}'},
{'from': 'User', 'value': '{turn 2 user message}', 'label': None},
{'from': 'Assitant', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'},
{'from': 'Assistant', 'value': '{turn 2 assistant message}', 'label': '{turn 2 assistant label}'},
],
"mask": "User",
"type": "VALUE_TO_TEXT",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import torch
from omegaconf import DictConfig, ListConfig
from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning.trainer.trainer import Trainer

from nemo.collections.common.metrics import MetricStringToTorchMetric
Expand All @@ -33,7 +34,6 @@
from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
from nemo.collections.nlp.modules.common.megatron.utils import get_iterator_k_split
from nemo.collections.nlp.modules.common.text_generation_utils import generate, get_computeprob_response

from nemo.collections.nlp.parts.mixins.nlp_adapter_mixins import NLPAdapterModelMixin
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.utils import AppState, logging
Expand Down Expand Up @@ -300,7 +300,11 @@ def _build_dataset(self, data_cfg, is_train=True):
'chat_prompt_tokens', None
), # special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '<extra_id_0>', 'turn_start': '<extra_id_1>', 'label_start': '<extra_id_2>', 'end_of_turn': '\n', "end_of_name": "\n"}
)
datasets.append(dataset)
if self.cfg.data.get("chat", False):
# chat dataset, overwrite the prompt template with the one from the dataset
OmegaConf.set_struct(self.cfg, True)
with open_dict(self.cfg):
self.cfg.prompt_format = dataset.get_prompt_template_example()

if is_train:
dataset = BlendableDataset(
Expand Down
49 changes: 49 additions & 0 deletions tests/collections/nlp/test_chat_sft_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,55 @@ def _mask_assistant_nolabel_test(self, tokenizer, ids_to_text):
finally:
os.remove(temp_file)

def _test_example_prompt(self, tokenizer):
random.seed(5)
temp_file = '/tmp/test_file.jsonl'
turn_num = 5
records = 5
try:
data_points = create_data_points(True, turn_num, records, temp_file, t2v=False)

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable data_points is not used.
d = GPTSFTChatDataset(
temp_file,
tokenizer,
4096,
1,
index_mapping_dir='/tmp/',
hf_dataset=True,
special_tokens=self.special_tokens,
)
conv = d.get_prompt_template_example()
expected = (
self.special_tokens['system_turn_start']
+ 'System'
+ self.special_tokens['end_of_name']
+ '{system message}'
+ self.special_tokens['end_of_turn']
)
for turn in range(2):
expected += (
self.special_tokens['turn_start']
+ 'User'
+ self.special_tokens['end_of_name']
+ f'{{turn {turn + 1} user message}}'
+ self.special_tokens['end_of_turn']
)
expected += self.special_tokens['turn_start'] + 'Assistant' + self.special_tokens['end_of_name']
expected += (
self.special_tokens['label_start']
+ f'{{turn {turn + 1} assistant label}}'
+ self.special_tokens['end_of_name']
)
expected += f'{{turn {turn + 1} assistant message}}' + self.special_tokens['end_of_turn']
expected += self.special_tokens['turn_start']
assert conv == expected
finally:
os.remove(temp_file)

@pytest.mark.unit
def test_43B_example_prompt(self):
tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
self._test_example_prompt(tokenizer)

@pytest.mark.unit
def test_43B_tokenizer_mask_user(self):
tokenizer = get_nmt_tokenizer(library='sentencepiece', tokenizer_model=TOKENIZER_FILE_43B)
Expand Down

0 comments on commit c99b55f

Please sign in to comment.