From 769abb352e122f55e77f90611a99a105bab64c8d Mon Sep 17 00:00:00 2001
From: Alexandre Marques <alexandre@neuralmagic.com>
Date: Mon, 1 Apr 2024 15:11:42 -0400
Subject: [PATCH] Updates to enable ultrachat200k

Ultrachat200k ships two training splits, one for SFT and another for DPO, so it has no plain "train" split. This PR accepts "train_sft" as a fallback when "train" is not present.
---
 .../transformers/finetune/data/data_helpers.py      | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py
index 243f4085023..31dcb53a920 100644
--- a/src/sparseml/transformers/finetune/data/data_helpers.py
+++ b/src/sparseml/transformers/finetune/data/data_helpers.py
@@ -128,9 +128,12 @@ def make_dataset_splits(
     train_split = eval_split = predict_split = calib_split = None
 
     if do_train:
-        if "train" not in tokenized_datasets:
+        if "train" in tokenized_datasets:
+            train_split = tokenized_datasets["train"]
+        elif "train_sft" in tokenized_datasets:
+            train_split = tokenized_datasets["train_sft"]
+        else:
             raise ValueError("--do_train requires a train dataset")
-        train_split = tokenized_datasets["train"]
     if do_eval:
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
@@ -142,7 +145,12 @@ def make_dataset_splits(
     if do_oneshot:
         calib_split = tokenized_datasets.get("calibration")
         if calib_split is None:
-            if "train" not in tokenized_datasets:
+            # No dedicated calibration split: calibrate on the training data,
+            # accepting "train_sft" when there is no plain "train" split.
+            if "train" in tokenized_datasets:
+                calib_split = tokenized_datasets["train"]
+            elif "train_sft" in tokenized_datasets:
+                calib_split = tokenized_datasets["train_sft"]
+            else:
                 raise ValueError("--do_oneshot requires a calibration dataset")
-            calib_split = tokenized_datasets["train"]