For local files, the finetuning script supports the following input file formats: json and jsonl (one JSON object per line). By default, the script expects the following key name:

- source: key for each conversation
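Before training, it can help to check a jsonl file against the requirements stated above: one JSON object per line, each holding the conversation under source (or under a custom key). The shell function below is a minimal sketch. The name check_jsonl is hypothetical, it needs python3 on the PATH, and it only checks that the key is present, not how the conversation stored under it is structured (the sample input files show that).

```shell
# Hypothetical helper: check_jsonl FILE [KEY]
# Exits non-zero unless every non-empty line of FILE parses as a JSON object
# containing KEY (default: "source", the script's default dialogue key).
# Covers the jsonl format only; the conversation value itself is not validated.
check_jsonl() {
  local file="$1" key="${2:-source}"
  python3 - "$file" "$key" <<'EOF'
import json
import sys

path, key = sys.argv[1], sys.argv[2]
with open(path, encoding="utf-8") as f:
    for lineno, line in enumerate(f, 1):
        if not line.strip():
            continue  # this checker skips blank lines
        record = json.loads(line)  # malformed JSON raises -> non-zero exit
        if not isinstance(record, dict) or key not in record:
            sys.exit(f"{path}:{lineno}: missing key {key!r}")
EOF
}
```

For example, check_jsonl my_data/train.jsonl checks a file that uses the default key, and check_jsonl my_data/train.jsonl <key_name> checks a file meant to be used with --dialogue_key <key_name> (the path is a placeholder).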
You can specify a custom key name by passing the flag --dialogue_key <key_name> to run_dialogue_generation.py. To view sample input files, see the files here. To see the list of all available options, run python run_dialogue_generation.py -h.

There are three ways to provide input data files to the script:
- with the flag --dataset_dir <path>, where <path> points to the directory containing files with the prefixes train, validation and test.
- with the flags --train_file <path> / --validation_file <path> / --test_file <path>.
- with a dataset from the Hugging Face Datasets library, using the flags --dataset_name <name> and --dataset_config_name <name> (optional).
For the following commands, we use --dataset_dir <path> to provide the input files.
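When using --dataset_dir, a quick check that the directory holds a file for each of the train, validation and test prefixes can save a failed run. The sketch below is illustrative only: check_dataset_dir is a hypothetical name, and it looks at file names, not file contents.

```shell
# Hypothetical helper: check_dataset_dir DIR
# Returns non-zero (and names the missing prefix on stderr) unless DIR
# contains at least one entry starting with each of: train, validation, test.
check_dataset_dir() {
  local dir="$1" split f found
  for split in train validation test; do
    found=no
    for f in "$dir"/"$split"*; do
      # An unmatched glob stays literal, so confirm the path really exists.
      [ -e "$f" ] && found=yes
    done
    if [ "$found" = no ]; then
      echo "missing a '$split*' file in $dir" >&2
      return 1
    fi
  done
}
```

For example, check_dataset_dir sample_inputs/ returns 0 when all three prefixes are present and otherwise reports the first one that is missing.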
For finetuning, followed by inference on the test set with the best model found during validation (on a single GPU), a minimal example is:
$ python ./run_dialogue_generation.py \
--model_name_or_path "csebuetnlp/banglat5" \
--dataset_dir "sample_inputs/" \
--output_dir "outputs/" \
--learning_rate=5e-4 \
--warmup_steps 5000 \
--label_smoothing_factor 0.1 \
--gradient_accumulation_steps 4 \
--weight_decay 0.1 \
--lr_scheduler_type "linear" \
--per_device_train_batch_size=8 \
--per_device_eval_batch_size=8 \
--max_source_length 256 \
--max_target_length 256 \
--logging_strategy "epoch" \
--save_strategy "epoch" \
--evaluation_strategy "epoch" \
--greater_is_better true --load_best_model_at_end \
--metric_for_best_model sacrebleu --evaluation_metric sacrebleu \
--num_train_epochs=5 \
--do_train --do_eval --do_predict \
--predict_with_generate
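The batch-size flags in the command above combine multiplicatively. Assuming the script trains with the Hugging Face Trainer (as the flag names suggest), each optimizer step consumes per_device_train_batch_size × gradient_accumulation_steps × number of GPUs examples, here 8 × 4 × 1 = 32. That figure also gives a rough count of optimizer steps for the whole run, which is worth comparing with --warmup_steps 5000: with the linear scheduler, a run shorter than the warmup never reaches the 5e-4 peak learning rate. A sketch with hypothetical helper names:

```shell
# Hypothetical helpers for sizing a run (integer arithmetic only).

# effective_batch PER_DEVICE_BATCH GRAD_ACCUM_STEPS NUM_GPUS
# Examples consumed per optimizer step.
effective_batch() {
  echo $(( $1 * $2 * $3 ))
}

# approx_total_steps NUM_EXAMPLES EFFECTIVE_BATCH EPOCHS
# Optimizer steps per epoch (rounded up), times the number of epochs.
# Approximate: the Trainer's own handling of partial batches can differ slightly.
approx_total_steps() {
  echo $(( ( ($1 + $2 - 1) / $2 ) * $3 ))
}

batch=$(effective_batch 8 4 1)        # values from the command above
approx_total_steps 40000 "$batch" 5   # 40000 is a made-up dataset size; prints 6250
```

With a made-up training set of 40,000 examples (counted after whatever preprocessing the script applies), the run takes about 6250 optimizer steps, so the first 5000 of them are warmup.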
- To compute metrics on the test set, or to run inference on raw data, use the following command:
$ python ./run_dialogue_generation.py \
--model_name_or_path <path/to/trained/model> \
--dataset_dir "sample_inputs/" \
--output_dir "outputs/" \
--per_device_eval_batch_size=8 \
--overwrite_output_dir \
--evaluation_metric sacrebleu \
--do_predict --predict_with_generate
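One way to fill in <path/to/trained/model> is with the best checkpoint from the finetuning run. Assuming the script relies on the Hugging Face Trainer's standard checkpointing (which --save_strategy and --load_best_model_at_end suggest), each outputs/checkpoint-<step> directory contains a trainer_state.json whose best_model_checkpoint field records the best checkpoint so far. The helper below (hypothetical name best_checkpoint, needs python3) reads that field from the newest checkpoint:

```shell
# Hypothetical helper: best_checkpoint [OUTPUT_DIR]
# Prints the best_model_checkpoint recorded in trainer_state.json of the
# newest checkpoint-<step> directory under OUTPUT_DIR (default: outputs).
best_checkpoint() {
  local dir="${1:-outputs}" latest d
  # Emit "<step><TAB><path>" per checkpoint, sort numerically by step, keep the last path.
  latest=$(for d in "$dir"/checkpoint-*; do
             [ -d "$d" ] && printf '%s\t%s\n' "${d##*checkpoint-}" "$d"
           done | sort -n | tail -n 1 | cut -f 2-)
  if [ -z "$latest" ]; then
    echo "no checkpoint-* directories in $dir" >&2
    return 1
  fi
  python3 -c 'import json, sys; print(json.load(open(sys.argv[1]))["best_model_checkpoint"])' \
    "$latest/trainer_state.json"
}
```

For example, pass --model_name_or_path "$(best_checkpoint outputs)". The recorded path is built from --output_dir as given at training time, so call the helper from the same working directory.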