diff --git a/tests/artifacts/testdata/__init__.py b/tests/artifacts/testdata/__init__.py index 32cba62d3..b6d8c0fff 100644 --- a/tests/artifacts/testdata/__init__.py +++ b/tests/artifacts/testdata/__init__.py @@ -60,8 +60,8 @@ TWITTER_COMPLAINTS_TOKENIZED_PARQUET = os.path.join( PARQUET_DATA_DIR, "twitter_complaints_tokenized_with_maykeye_tinyllama_v0.parquet" ) -CHAT_DATA_SINGLE_TURN = os.path.join(JSON_DATA_DIR, "single_turn_chat.jsonl") -CHAT_DATA_MULTI_TURN = os.path.join(JSON_DATA_DIR, "multi_turn_chat.jsonl") +CHAT_DATA_SINGLE_TURN = os.path.join(JSONL_DATA_DIR, "single_turn_chat.jsonl") +CHAT_DATA_MULTI_TURN = os.path.join(JSONL_DATA_DIR, "multi_turn_chat.jsonl") EMPTY_DATA = os.path.join(JSON_DATA_DIR, "empty_data.json") MALFORMATTED_DATA = os.path.join(JSON_DATA_DIR, "malformatted_data.json") diff --git a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/added_tokens.json b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/added_tokens.json index 9ff990cbc..3b01326a4 100644 --- a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/added_tokens.json +++ b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/added_tokens.json @@ -1,4 +1,5 @@ { + "": 32003, "<|assistant|>": 32001, "<|system|>": 32002, "<|user|>": 32000 diff --git a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/special_tokens_map.json b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/special_tokens_map.json index 90ffa1c36..160e09e9c 100644 --- a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/special_tokens_map.json +++ b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/special_tokens_map.json @@ -1,26 +1,8 @@ { "additional_special_tokens": [ - { - "content": "<|user|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - { - "content": "<|assistant|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - }, - { - "content": "<|system|>", - "lstrip": false, - "normalized": false, - "rstrip": false, - "single_word": false - } + "<|user|>", + "<|assistant|>", + "<|system|>" ], "bos_token": { "content": "", @@ -36,6 +18,13 @@ "rstrip": false, "single_word": false }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, "unk_token": { "content": "", "lstrip": false, diff --git a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer.json b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer.json index 1543ffa79..dc4d81331 100644 --- a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer.json +++ b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer.json @@ -56,6 +56,15 @@ "rstrip": false, "normalized": false, "special": true + }, + { + "id": 32003, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true } ], "normalizer": { diff --git a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer_config.json b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer_config.json index 4ce40d8f8..9a7c465a2 100644 --- a/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer_config.json +++ b/tests/artifacts/testdata/tinyllama_tokenizer_special_tokens/tokenizer_config.json @@ -50,6 +50,14 @@ "rstrip": false, "single_word": false, "special": true + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true } }, "additional_special_tokens": [ @@ -62,7 +70,7 @@ "eos_token": "", "legacy": true, "model_max_length": 2048, - "pad_token": null, + "pad_token": "", "sp_model_kwargs": {}, "tokenizer_class": "LlamaTokenizer", "unk_token": "", diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 6653a71f5..7ca54bd93 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -888,7 +888,7 @@ def test_run_chat_style_ft(dataset_path): train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir - sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args) + sft_trainer.train(model_args, data_args, train_args) # validate full ft configs _validate_training(tempdir) @@ -917,7 +917,20 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile): {% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}\ {% endif %}\ {% endfor %}" + data_args.response_template = "<|assistant|>" data_args.instruction_template = "<|user|>" + data_args.dataset_text_field = "new_formatted_field" + + handler_kwargs = {"dataset_text_field": data_args.dataset_text_field} + kwargs = { + "fn_kwargs": handler_kwargs, + "batched": False, + "remove_columns": "all", + } + + handler_config = DataHandlerConfig( + name="apply_tokenizer_chat_template", arguments=kwargs + ) model_args = copy.deepcopy(MODEL_ARGS) model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA @@ -932,13 +945,13 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile): data = yaml.safe_load(f) datasets = data["datasets"] for i, d in enumerate(datasets): - d["data_paths"][0] = datafiles[i] + d["data_paths"] = [datafiles[i]] # Basic chat datasets don't need data handling - del d["data_handlers"] + d["data_handlers"] = [asdict(handler_config)] yaml.dump(data, temp_yaml_file) data_args.data_config_path = temp_yaml_file.name - sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args) + sft_trainer.train(model_args, data_args, train_args) # validate full ft configs _validate_training(tempdir) diff --git a/tuning/config/configs.py b/tuning/config/configs.py index 6786d5410..32c5315f9 100644 --- a/tuning/config/configs.py +++ b/tuning/config/configs.py @@ -69,10 +69,6 @@ class DataArguments: default=None, metadata={"help": "Path to the training data in JSON/JSONL format."}, ) - response_template: str = field( - default=None, - metadata={"help": "Response template, separator to train on completions only"}, - ) dataset_text_field: str = field( default=None, metadata={ @@ -105,16 +101,25 @@ class DataArguments: chat_template: str = field( default=None, metadata={ - "help": "chat template to use for tokenization. \ - No need to pass this if the tokenizer already has a chat_template \ - if passed, it will overwrite tokenizer.chat_template if it exists" + "help": "Chat template to use for tokenization. \ + No need to pass this if the tokenizer already has a chat_template. \ + If passed, it will overwrite tokenizer.chat_template if it exists." + }, + ) + response_template: str = field( + default=None, + metadata={ + "help": "For completions only style training represents separator to train on; \ + For chat style training this needs to be passed with instruction_template\ + as piece of text which determines the start of assistant response" }, ) instruction_template: str = field( default=None, metadata={ "help": "Should be provided for chat training. \ - Piece of text that determines the start of human response" + Piece of text that determines the start of human response\ + Passed in conjunction with response_template" }, ) diff --git a/tuning/data/data_handlers.py b/tuning/data/data_handlers.py index 6a821ec5c..d666a6e76 100644 --- a/tuning/data/data_handlers.py +++ b/tuning/data/data_handlers.py @@ -137,8 +137,25 @@ def replace_text(match_obj): } +def apply_tokenizer_chat_template( + element: Dict[str, str], + tokenizer: AutoTokenizer, + dataset_text_field: str, + **kwargs, +): + if tokenizer.chat_template is None: + raise ValueError( + "Tokenizer does not contain tokenizer.chat_template\ + please pass data_args.chat_template" + ) + return { + f"{dataset_text_field}": tokenizer.apply_chat_template(element, tokenize=False) + } + + AVAILABLE_DATA_HANDLERS = { "tokenize_and_apply_input_masking": tokenize_and_apply_input_masking, "apply_dataset_formatting": apply_dataset_formatting, "apply_custom_data_formatting_template": apply_custom_data_formatting_template, + "apply_tokenizer_chat_template": apply_tokenizer_chat_template, } diff --git a/tuning/data/setup_dataprocessor.py b/tuning/data/setup_dataprocessor.py index 3ba3c9e5f..2bae18bb7 100644 --- a/tuning/data/setup_dataprocessor.py +++ b/tuning/data/setup_dataprocessor.py @@ -151,6 +151,25 @@ def _get_dataset_formatting_handlers(data_args, packing): return [handler], dataset_text_field +### Default Format 3 +def _get_chat_dataset_handlers(data_args, tokenizer_kwargs): + + if data_args.dataset_text_field is None: + data_args.dataset_text_field = "new_formatted_field" + + fn_kwargs = {} + fn_kwargs["dataset_text_field"] = data_args.dataset_text_field + fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs + + kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"} + + handlers = [ + DataHandlerConfig("apply_tokenizer_chat_template", arguments=kwargs), + ] + + return handlers, data_args.dataset_text_field + + ### Default Data format def _get_default_dataset_handlers(data_args, tokenizer_kwargs): @@ -236,15 +255,17 @@ def _process_raw_data_args( handlers, dataset_text_field = _get_pretokenized_dataset_handlers( data_args, packing, (is_eval_dataset_present and not is_evaldata_tokenized) ) + elif data_args.instruction_template and data_args.response_template: + # Data Format 2: Chat dataset with instruction and response template + # We don't do processing for chat dataset + handlers, dataset_text_field = _get_chat_dataset_handlers( + data_args, tokenizer_kwargs + ) elif data_args.data_formatter_template or data_args.dataset_text_field: - # Data Format 2: Single Sequence Dataset + # Data Format 3: Single Sequence Dataset handlers, dataset_text_field = _get_dataset_formatting_handlers( data_args, packing ) - elif data_args.instruction_template and data_args.response_template: - # Data Format 3: Chat dataset with instruction and response template - # We don't do processing for chat dataset - handlers, dataset_text_field = [], None else: # Default Data Format: Dataset with Input/Output Fields handlers, dataset_text_field = _get_default_dataset_handlers(