Skip to content

Commit

Permalink
add chat template data handler
Browse files Browse the repository at this point in the history
Signed-off-by: Dushyant Behl <dushyantbehl@in.ibm.com>
  • Loading branch information
dushyantbehl committed Dec 18, 2024
1 parent 1ab11e7 commit ba9137e
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 43 deletions.
4 changes: 2 additions & 2 deletions tests/artifacts/testdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join(
PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet"
)
CHAT_DATA_SINGLE_TURN = os.path.join(JSON_DATA_DIR, "single_turn_chat.jsonl")
CHAT_DATA_MULTI_TURN = os.path.join(JSON_DATA_DIR, "multi_turn_chat.jsonl")
CHAT_DATA_SINGLE_TURN = os.path.join(JSONL_DATA_DIR, "single_turn_chat.jsonl")
CHAT_DATA_MULTI_TURN = os.path.join(JSONL_DATA_DIR, "multi_turn_chat.jsonl")
EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json")
MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
{
"<pad>": 32003,
"<|assistant|>": 32001,
"<|system|>": 32002,
"<|user|>": 32000
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,8 @@
{
"additional_special_tokens": [
{
"content": "<|user|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|assistant|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
{
"content": "<|system|>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
}
"<|user|>",
"<|assistant|>",
"<|system|>"
],
"bos_token": {
"content": "<s>",
Expand All @@ -36,6 +18,13 @@
"rstrip": false,
"single_word": false
},
"pad_token": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false
},
"unk_token": {
"content": "<unk>",
"lstrip": false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 32003,
"content": "<pad>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@
"rstrip": false,
"single_word": false,
"special": true
},
"32003": {
"content": "<pad>",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
"special": true
}
},
"additional_special_tokens": [
Expand All @@ -62,7 +70,7 @@
"eos_token": "</s>",
"legacy": true,
"model_max_length": 2048,
"pad_token": null,
"pad_token": "<pad>",
"sp_model_kwargs": {},
"tokenizer_class": "LlamaTokenizer",
"unk_token": "<unk>",
Expand Down
51 changes: 45 additions & 6 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,10 +888,23 @@ def test_run_chat_style_ft(dataset_path):
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)
sft_trainer.train(model_args, data_args, train_args)

# validate full ft configs
# validate the configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
'<|user|>\nProvide two rhyming words for the word "love"\n\
<nopace></s><|assistant|>',
max_new_tokens=50,
)
assert len(output_inference) > 0
assert 'Provide two rhyming words for the word "love"' in output_inference


@pytest.mark.parametrize(
Expand All @@ -917,7 +930,20 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}\
{% endif %}\
{% endfor %}"
data_args.response_template = "<|assistant|>"
data_args.instruction_template = "<|user|>"
data_args.dataset_text_field = "new_formatted_field"

handler_kwargs = {"dataset_text_field": data_args.dataset_text_field}
kwargs = {
"fn_kwargs": handler_kwargs,
"batched": False,
"remove_columns": "all",
}

handler_config = DataHandlerConfig(
name="apply_tokenizer_chat_template", arguments=kwargs
)

model_args = copy.deepcopy(MODEL_ARGS)
model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA
Expand All @@ -932,16 +958,29 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
data = yaml.safe_load(f)
datasets = data["datasets"]
for i, d in enumerate(datasets):
d["data_paths"][0] = datafiles[i]
d["data_paths"] = [datafiles[i]]
# Basic chat datasets don't need data handling
del d["data_handlers"]
d["data_handlers"] = [asdict(handler_config)]
yaml.dump(data, temp_yaml_file)
data_args.data_config_path = temp_yaml_file.name

sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args)
sft_trainer.train(model_args, data_args, train_args)

# validate full ft configs
# validate the configs
_validate_training(tempdir)
checkpoint_path = _get_checkpoint_path(tempdir)

# Load the model
loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)

# Run inference on the text
output_inference = loaded_model.run(
'<|user|>\nProvide two rhyming words for the word "love"\n\
<nopace></s><|assistant|>',
max_new_tokens=50,
)
assert len(output_inference) > 0
assert 'Provide two rhyming words for the word "love"' in output_inference


############################# Helper functions #############################
Expand Down
21 changes: 13 additions & 8 deletions tuning/config/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,6 @@ class DataArguments:
default=None,
metadata={"help": "Path to the training data in JSON/JSONL format."},
)
response_template: str = field(
default=None,
metadata={"help": "Response template, separator to train on completions only"},
)
dataset_text_field: str = field(
default=None,
metadata={
Expand Down Expand Up @@ -105,16 +101,25 @@ class DataArguments:
chat_template: str = field(
default=None,
metadata={
"help": "chat template to use for tokenization. \
No need to pass this if the tokenizer already has a chat_template \
if passed, it will overwrite tokenizer.chat_template if it exists"
"help": "Chat template to use for tokenization. \
No need to pass this if the tokenizer already has a chat_template. \
If passed, it will overwrite tokenizer.chat_template if it exists."
},
)
response_template: str = field(
default=None,
metadata={
"help": "For completions only style training represents separator to train on; \
For chat style training this needs to be passed with instruction_template\
as piece of text which determines the start of assistant response"
},
)
instruction_template: str = field(
default=None,
metadata={
"help": "Should be provided for chat training. \
Piece of text that determines the start of human response"
Piece of text that determines the start of human response\
Passed in conjunction with response_template"
},
)

Expand Down
17 changes: 17 additions & 0 deletions tuning/data/data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,25 @@ def replace_text(match_obj):
}


def apply_tokenizer_chat_template(
element: Dict[str, str],
tokenizer: AutoTokenizer,
dataset_text_field: str,
**kwargs,
):
if tokenizer.chat_template is None:
raise ValueError(
"Tokenizer does not contain tokenizer.chat_template\
please pass data_args.chat_template"
)
return {
f"{dataset_text_field}": tokenizer.apply_chat_template(element, tokenize=False)
}


AVAILABLE_DATA_HANDLERS = {
"tokenize_and_apply_input_masking": tokenize_and_apply_input_masking,
"apply_dataset_formatting": apply_dataset_formatting,
"apply_custom_data_formatting_template": apply_custom_data_formatting_template,
"apply_tokenizer_chat_template": apply_tokenizer_chat_template,
}
31 changes: 26 additions & 5 deletions tuning/data/setup_dataprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,25 @@ def _get_dataset_formatting_handlers(data_args, packing):
return [handler], dataset_text_field


### Default Format 3
def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):

if data_args.dataset_text_field is None:
data_args.dataset_text_field = "new_formatted_field"

fn_kwargs = {}
fn_kwargs["dataset_text_field"] = data_args.dataset_text_field
fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs

kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}

handlers = [
DataHandlerConfig("apply_tokenizer_chat_template", arguments=kwargs),
]

return handlers, data_args.dataset_text_field


### Default Data format
def _get_default_dataset_handlers(data_args, tokenizer_kwargs):

Expand Down Expand Up @@ -236,15 +255,17 @@ def _process_raw_data_args(
handlers, dataset_text_field = _get_pretokenized_dataset_handlers(
data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized)
)
elif data_args.instruction_template and data_args.response_template:
# Data Format 2: Chat dataset with instruction and response template
# We don't do processing for chat dataset
handlers, dataset_text_field = _get_chat_dataset_handlers(
data_args, tokenizer_kwargs
)
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 2: Single Sequence Dataset
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
)
elif data_args.instruction_template and data_args.response_template:
# Data Format 3: Chat dataset with instruction and response template
# We don't do processing for chat dataset
handlers, dataset_text_field = [], None
else:
# Default Data Format: Dataset with Input/Output Fields
handlers, dataset_text_field = _get_default_dataset_handlers(
Expand Down

0 comments on commit ba9137e

Please sign in to comment.