diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py
index 243f4085023..31dcb53a920 100644
--- a/src/sparseml/transformers/finetune/data/data_helpers.py
+++ b/src/sparseml/transformers/finetune/data/data_helpers.py
@@ -128,9 +128,12 @@ def make_dataset_splits(
     train_split = eval_split = predict_split = calib_split = None
 
     if do_train:
-        if "train" not in tokenized_datasets:
+        if "train" in tokenized_datasets:
+            train_split = tokenized_datasets["train"]
+        elif "train_sft" in tokenized_datasets:
+            train_split = tokenized_datasets["train_sft"]
+        else:
             raise ValueError("--do_train requires a train dataset")
-        train_split = tokenized_datasets["train"]
     if do_eval:
         if "validation" not in tokenized_datasets:
             raise ValueError("--do_eval requires a validation dataset")
@@ -142,7 +145,10 @@
     if do_oneshot:
         calib_split = tokenized_datasets.get("calibration")
         if calib_split is None:
-            if "train" not in tokenized_datasets:
+            if "train" in tokenized_datasets:
+                calib_split = tokenized_datasets["train"]
+            elif "train_sft" in tokenized_datasets:
+                calib_split = tokenized_datasets["train_sft"]
+            else:
                 raise ValueError("--do_oneshot requires a calibration dataset")
-            calib_split = tokenized_datasets["train"]