diff --git a/eole/config/training.py b/eole/config/training.py
index 6078db7b..1d7092d8 100644
--- a/eole/config/training.py
+++ b/eole/config/training.py
@@ -259,9 +259,11 @@ class TrainingConfig(
     )
     lm_prior_lambda: float = Field(default=0.0, description="LM Prior Lambda")
     lm_prior_tau: float = Field(default=1.0, description="LM Prior Tau")
-
-    estim_loss_lambda: float = Field(
-        default=1.0, description="Weight applied to estimator loss"
+    estim_loss_lambda: List[float] = Field(
+        default=[1.0], description="Weight applied to estimator loss"
+    )
+    estim_loss_lambda_steps: List[int] = Field(
+        default=[0], description="Steps at which estimator loss lambda changes"
     )
     score_threshold: float = Field(
         default=0.68, description="Threshold to filterout data"
diff --git a/eole/trainer.py b/eole/trainer.py
index dd43eb5a..ed468ed7 100644
--- a/eole/trainer.py
+++ b/eole/trainer.py
@@ -38,6 +38,7 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None):
     train_loss = LossCompute.from_config(config, model, vocabs["tgt"])
     valid_loss = LossCompute.from_config(config, model, vocabs["tgt"], train=False)
     estim_loss_lambda = config.training.estim_loss_lambda
+    estim_loss_lambda_steps = config.training.estim_loss_lambda_steps
 
     scoring_preparator = ScoringPreparator(vocabs, config)
     validset_transforms = getattr(config.data.get("valid", None), "transforms", None)
@@ -104,6 +105,7 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None):
         dropout_steps=dropout_steps,
         zero_out_prompt_loss=zero_out_prompt_loss,
         estim_loss_lambda=estim_loss_lambda,
+        estim_loss_lambda_steps=estim_loss_lambda_steps,
     )
     return trainer
 
@@ -147,7 +149,8 @@ class Trainer(object):
         dropout_steps (list): dropout values scheduling in steps.
         zero_out_prompt_loss (bool): whether to zero-out the prompt loss
             (mostly for LLM finetuning).
-        estim_loss_lambda (float): weight applied to estimator loss"""
+        estim_loss_lambda (List[float]): weight applied to estimator loss
+        estim_loss_lambda_steps (List[int]): steps at which estimator loss lambda changes"""
 
     def __init__(
         self,
@@ -175,13 +178,16 @@ def __init__(
         attention_dropout=[0.1],
         dropout_steps=[0],
         zero_out_prompt_loss=False,
-        estim_loss_lambda=1.0,
+        estim_loss_lambda=[1.0],
+        estim_loss_lambda_steps=[0],
     ):
 
         # Basic attributes.
         self.model = model
         self.train_loss = train_loss
-        self.estim_loss_lambda = estim_loss_lambda
+        self.estim_loss_lambda_l = estim_loss_lambda
+        self.estim_loss_lambda = estim_loss_lambda[0]
+        self.estim_loss_lambda_steps = estim_loss_lambda_steps
         self.valid_loss = valid_loss
 
         self.scoring_preparator = scoring_preparator
@@ -241,6 +247,15 @@ def _maybe_update_dropout(self, step):
                     % (self.dropout[i], self.attention_dropout[i], step)
                 )
 
+    def _maybe_update_estim_lambda(self, step):
+        for i in range(len(self.estim_loss_lambda_steps)):
+            if step > 1 and step == self.estim_loss_lambda_steps[i] + 1:
+                self.estim_loss_lambda = self.estim_loss_lambda_l[i]
+                logger.info(
+                    "Updated estimator lambda to %f at step %d"
+                    % (self.estim_loss_lambda_l[i], step)
+                )
+
     def _accum_batches(self, iterator):
         batches = []
         normalization = 0
@@ -318,6 +333,7 @@ def train(
             step = self.optim.training_step
             # UPDATE DROPOUT
            self._maybe_update_dropout(step)
+            self._maybe_update_estim_lambda(step)
 
             if self.n_gpu > 1 and self.parallel_mode == "data_parallel":
                 normalization = sum(
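
Note on the schedule semantics (not part of the patch): mirroring the existing _maybe_update_dropout convention, estim_loss_lambda[i] takes effect at training step estim_loss_lambda_steps[i] + 1, and the first value applies from the start because __init__ seeds self.estim_loss_lambda = estim_loss_lambda[0]. Below is a minimal standalone sketch of that behavior, with hypothetical schedule values, assuming estim_loss_lambda_steps is sorted ascending:

    # Sketch only: replicates the cumulative effect of calling
    # _maybe_update_estim_lambda(step) once per training step.
    def lambda_at_step(step, lambdas, steps):
        current = lambdas[0]  # initial value, seeded in __init__ in the patch
        for i, s in enumerate(steps):
            if step > 1 and step >= s + 1:  # lambdas[i] active from step s + 1
                current = lambdas[i]
        return current

    # Hypothetical schedule: weight the estimator loss at 0.1, raising it
    # to 1.0 after step 5000 (i.e. estim_loss_lambda: [0.1, 1.0],
    # estim_loss_lambda_steps: [0, 5000]).
    assert lambda_at_step(1, [0.1, 1.0], [0, 5000]) == 0.1
    assert lambda_at_step(5000, [0.1, 1.0], [0, 5000]) == 0.1
    assert lambda_at_step(5001, [0.1, 1.0], [0, 5000]) == 1.0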