diff --git a/recipes/quickstart/finetuning/datasets/README.md b/recipes/quickstart/finetuning/datasets/README.md
index 8795ca96d..2255d743a 100644
--- a/recipes/quickstart/finetuning/datasets/README.md
+++ b/recipes/quickstart/finetuning/datasets/README.md
@@ -34,7 +34,7 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 ```
 For an example `get_custom_dataset` you can look at the provided datasets in llama_recipes.datasets or [custom_dataset.py](./custom_dataset.py).
 The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
-The split signals wether to return the training or validation dataset.
+The split signals whether to return the training or validation dataset.
 The default function name is `get_custom_dataset` but this can be changed as described below.
 
 In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
@@ -47,6 +47,8 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
+If you need to use a custom data collator, name it `get_data_collator` in the same file as `get_foo`.
+
 ### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../../../../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
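Note on the README change above: a minimal sketch of what a custom dataset file with both functions might look like. The file name `my_dataset.py`, the toy samples, the callable `tokenizer` returning a list of ints, and the single-argument signature of `get_data_collator` are all assumptions made for illustration; the diff only confirms the function names and the `get_custom_dataset` signature.

```python
# my_dataset.py (hypothetical file name)


def get_foo(dataset_config, tokenizer, split: str):
    """Dataset entry point; same signature as get_custom_dataset in the README."""
    # Toy samples so the sketch runs without downloading anything.
    texts = ["hello world", "hi"] if split == "train" else ["validation sample"]
    # Assumption: tokenizer is a callable that maps a string to a list of ints.
    return [{"input_ids": tokenizer(text)} for text in texts]


def get_data_collator(dataset_processer):
    """Collator factory. The name is what the loader looks up; the argument
    list here is an assumption, not taken from the diff."""

    def collate(batch):
        # Right-pad every sample with 0 up to the longest sequence in the batch.
        longest = max(len(sample["input_ids"]) for sample in batch)
        return {
            "input_ids": [
                sample["input_ids"] + [0] * (longest - len(sample["input_ids"]))
                for sample in batch
            ]
        }

    return collate
```

With this layout, `--custom_dataset.file "my_dataset.py:get_foo"` would select `get_foo` for the dataset, while the collator is found by its fixed name in the same file.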
diff --git a/src/llama_recipes/datasets/custom_dataset.py b/src/llama_recipes/datasets/custom_dataset.py
index 278fcfe54..7a4a453b5 100644
--- a/src/llama_recipes/datasets/custom_dataset.py
+++ b/src/llama_recipes/datasets/custom_dataset.py
@@ -37,7 +37,7 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 
 def get_data_collator(dataset_processer,dataset_config):
     if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
+        module_path, func_name = dataset_config.file.split(":")[0], "get_data_collator"
     else:
         module_path, func_name = dataset_config.file, "get_data_collator"
 
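Note on the `custom_dataset.py` change above: the patched branch can be sketched in isolation as follows. The helper name `resolve_collator_target` and the file name `my_dataset.py` are hypothetical; only the branch logic is taken from the diff. Before the change, a spec such as `my_dataset.py:get_foo` made the collator lookup use `get_foo`; after it, the suffix is dropped and the name is always `get_data_collator`.

```python
def resolve_collator_target(file_spec: str):
    """Return (module_path, func_name) the way the patched branch does."""
    if ":" in file_spec:
        # Keep only the path. The part after ":" names the dataset function,
        # not the collator, so it is ignored here.
        module_path, func_name = file_spec.split(":")[0], "get_data_collator"
    else:
        module_path, func_name = file_spec, "get_data_collator"
    return module_path, func_name


# Both forms now resolve to the same collator name.
print(resolve_collator_target("my_dataset.py:get_foo"))
print(resolve_collator_target("my_dataset.py"))
```

Since both branches now yield the same function name, the conditional could arguably be collapsed to `file_spec.split(":")[0]`, which returns the whole string when no `:` is present.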