@@ -822,18 +822,19 @@ def __init__(
         )
 
         # Loss function
-        if args.loss_type == "nll":
-            pass  # use the default loss
-        elif args.loss_type == "dft":
-            if compute_loss_func is not None:
-                raise ValueError(
-                    "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
-                    "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so passing a "
-                    "`compute_loss_func` is not allowed."
-                )
-            compute_loss_func = dft_loss
-        else:
-            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
+        if not args.use_liger_kernel:  # liger supports dft loss by just passing use_token_scaling=True
+            if args.loss_type == "nll":
+                pass  # use the default loss
+            elif args.loss_type == "dft":
+                if compute_loss_func is not None:
+                    raise ValueError(
+                        "You passed a `compute_loss_func` together with `loss_type='dft'` to the `SFTTrainer`. "
+                        "When using `loss_type='dft'`, the loss function is internally set to the DFT loss, so "
+                        "passing a `compute_loss_func` is not allowed."
+                    )
+                compute_loss_func = dft_loss
+            else:
+                raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
 
         # Initialize the metrics
         self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
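For context, `dft_loss` (used on the non-Liger path above) implements the DFT objective: the standard next-token NLL in which each token's loss is scaled by the detached probability the model assigns to the target token. A minimal sketch of that idea, assuming a `compute_loss_func`-style signature `(outputs, labels, num_items_in_batch)`; the helper name and exact reduction are illustrative, not TRL's implementation:

```python
import torch
import torch.nn.functional as F


def dft_loss_sketch(outputs, labels, num_items_in_batch=None):
    # Shift so position t predicts token t+1.
    logits = outputs.logits[:, :-1, :]
    labels = labels[:, 1:]
    mask = labels != -100

    logprobs = F.log_softmax(logits.float(), dim=-1)
    target = labels.clamp(min=0).unsqueeze(-1)
    token_logprobs = logprobs.gather(-1, target).squeeze(-1)  # log p(target_t)

    # DFT: weight each token's NLL by its detached target probability,
    # down-weighting tokens the model currently assigns low probability.
    per_token_loss = -token_logprobs.exp().detach() * token_logprobs
    per_token_loss = per_token_loss * mask

    denom = num_items_in_batch if num_items_in_batch is not None else mask.sum()
    return per_token_loss.sum() / denom
```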
@@ -1113,6 +1114,11 @@ def compute_loss(
 
         # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
         inputs["use_cache"] = False
+        # Request token accuracy from Liger kernel and set token scaling if using DFT loss
+        if self.args.use_liger_kernel:
+            inputs["return_token_accuracy"] = True
+            inputs["use_token_scaling"] = self.args.loss_type == "dft"
+
         (loss, outputs) = super().compute_loss(
             model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch
         )
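With this change, the Liger/DFT combination is selected purely through the config: no Python-side `compute_loss_func` is installed, and `compute_loss()` forwards `return_token_accuracy=True` and `use_token_scaling=True` to the Liger-patched forward pass. A hedged usage sketch (model name, output path, and `dataset` are placeholders):

```python
from trl import SFTConfig, SFTTrainer

args = SFTConfig(
    output_dir="sft-dft-liger",  # placeholder path
    use_liger_kernel=True,       # use Liger fused kernels
    loss_type="dft",             # DFT handled inside the kernel via use_token_scaling
)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",   # placeholder model
    args=args,
    train_dataset=dataset,       # assumed to be defined elsewhere
)
trainer.train()
```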
@@ -1151,8 +1157,11 @@ def compute_loss(
             self._total_train_tokens += num_tokens_in_batch
         self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
 
-        # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        if not self.args.use_liger_kernel:
+        if self.args.use_liger_kernel:
+            token_accuracy = self.accelerator.gather_for_metrics(outputs.token_accuracy).mean().item()
+            self._metrics[mode]["mean_token_accuracy"].append(token_accuracy)
+        else:
+            # Compute accuracy from logits using argmax (traditional method)
             with torch.no_grad():
                 if "shift_labels" in inputs:
                     # When using CP, labels are pre-shifted. We must use these (and cannot manually shift) because:
@@ -1190,10 +1199,12 @@ def compute_loss(
                 total_sum = total_tokens.sum()
                 accuracy = (correct_tokens.sum() / total_sum).item() if total_sum > 0 else 0.0
                 self._metrics[mode]["mean_token_accuracy"].append(accuracy)
-        if self.aux_loss_enabled:
-            aux_loss = outputs.aux_loss
-            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
-            self._metrics[mode]["aux_loss"].append(aux_loss)
+
+        # Log auxiliary loss if enabled (applies to both Liger and non-Liger)
+        if self.aux_loss_enabled:
+            aux_loss = outputs.aux_loss
+            aux_loss = self.accelerator.gather_for_metrics(aux_loss).mean().item()
+            self._metrics[mode]["aux_loss"].append(aux_loss)
 
         return (loss, outputs) if return_outputs else loss
 