estim lambda scheduler (#40)
vince62s authored Jun 20, 2024
1 parent 799126e commit 187b36d
Showing 2 changed files with 24 additions and 6 deletions.
8 changes: 5 additions & 3 deletions eole/config/training.py

@@ -259,9 +259,11 @@ class TrainingConfig(
     )
     lm_prior_lambda: float = Field(default=0.0, description="LM Prior Lambda")
     lm_prior_tau: float = Field(default=1.0, description="LM Prior Tau")
-
-    estim_loss_lambda: float = Field(
-        default=1.0, description="Weight applied to estimator loss"
+    estim_loss_lambda: List[float] = Field(
+        default=[1.0], description="Weight applied to estimator loss"
     )
+    estim_loss_lambda_steps: List[int] = Field(
+        default=[0], description="Steps at which estimator loss lambda changes"
+    )
     score_threshold: float = Field(
         default=0.68, description="Threshold to filterout data"
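For context (not part of the diff): the two new lists pair up index by index, mirroring the existing dropout/dropout_steps convention in TrainingConfig. A minimal sketch with hypothetical values:

```python
# Hypothetical training-config override (values are illustrative, not from this
# commit): keep the estimator loss at full weight for the first 10k steps, then
# down-weight it. Entry i of estim_loss_lambda takes effect once training
# passes estim_loss_lambda_steps[i].
training_overrides = {
    "estim_loss_lambda": [1.0, 0.5, 0.1],
    "estim_loss_lambda_steps": [0, 10000, 20000],
}
```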
22 changes: 19 additions & 3 deletions eole/trainer.py

@@ -38,6 +38,7 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None):
     train_loss = LossCompute.from_config(config, model, vocabs["tgt"])
     valid_loss = LossCompute.from_config(config, model, vocabs["tgt"], train=False)
     estim_loss_lambda = config.training.estim_loss_lambda
+    estim_loss_lambda_steps = config.training.estim_loss_lambda_steps

     scoring_preparator = ScoringPreparator(vocabs, config)
     validset_transforms = getattr(config.data.get("valid", None), "transforms", None)
@@ -104,6 +105,7 @@ def build_trainer(config, device_id, model, vocabs, optim, model_saver=None):
        dropout_steps=dropout_steps,
        zero_out_prompt_loss=zero_out_prompt_loss,
        estim_loss_lambda=estim_loss_lambda,
+       estim_loss_lambda_steps=estim_loss_lambda_steps,
    )
    return trainer

@@ -147,7 +149,8 @@ class Trainer(object):
        dropout_steps (list): dropout values scheduling in steps.
        zero_out_prompt_loss (bool): whether to zero-out the prompt loss
            (mostly for LLM finetuning).
-       estim_loss_lambda (float): weight applied to estimator loss"""
+       estim_loss_lambda (List[float]): weight applied to estimator loss
+       estim_loss_lambda_steps (List[int]): steps to apply to estimator values"""

    def __init__(
        self,
@@ -175,13 +178,16 @@ def __init__(
        attention_dropout=[0.1],
        dropout_steps=[0],
        zero_out_prompt_loss=False,
-       estim_loss_lambda=1.0,
+       estim_loss_lambda=[1.0],
+       estim_loss_lambda_steps=[0],
    ):
        # Basic attributes.

        self.model = model
        self.train_loss = train_loss
-       self.estim_loss_lambda = estim_loss_lambda
+       self.estim_loss_lambda_l = estim_loss_lambda
+       self.estim_loss_lambda = estim_loss_lambda[0]
+       self.estim_loss_lambda_steps = estim_loss_lambda_steps
        self.valid_loss = valid_loss

        self.scoring_preparator = scoring_preparator
@@ -241,6 +247,15 @@ def _maybe_update_dropout(self, step):
                    % (self.dropout[i], self.attention_dropout[i], step)
                )

+   def _maybe_update_estim_lambda(self, step):
+       for i in range(len(self.estim_loss_lambda_steps)):
+           if step > 1 and step == self.estim_loss_lambda_steps[i] + 1:
+               self.estim_loss_lambda = self.estim_loss_lambda_l[i]
+               logger.info(
+                   "Updated estimator lambda to %f at step %d"
+                   % (self.estim_loss_lambda_l[i], step)
+               )
+
    def _accum_batches(self, iterator):
        batches = []
        normalization = 0
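As a reading aid (not part of the commit): the update above only fires when step equals estim_loss_lambda_steps[i] + 1 with step > 1, so entry 0 effectively serves as the initial value set in __init__, and each later entry takes effect one step after its boundary. A standalone sketch of that behavior, assuming the step list is sorted ascending:

```python
def scheduled_lambda(step, lambdas, boundaries):
    """Return the estimator-loss lambda in effect at `step`, assuming the
    trainer runs the update every step and `boundaries` is sorted ascending.
    Sketch only; not eole code."""
    current = lambdas[0]  # value set in __init__, before any update fires
    for lam, boundary in zip(lambdas, boundaries):
        # the trainer switches to lambdas[i] at step boundaries[i] + 1 (step > 1)
        if boundary + 1 > 1 and step >= boundary + 1:
            current = lam
    return current


# With the hypothetical schedule from the config example above:
lambdas, boundaries = [1.0, 0.5, 0.1], [0, 10000, 20000]
assert scheduled_lambda(5000, lambdas, boundaries) == 1.0
assert scheduled_lambda(10001, lambdas, boundaries) == 0.5  # switches after step 10000
assert scheduled_lambda(25000, lambdas, boundaries) == 0.1
```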
@@ -318,6 +333,7 @@ def train(
            step = self.optim.training_step
            # UPDATE DROPOUT
            self._maybe_update_dropout(step)
+           self._maybe_update_estim_lambda(step)

            if self.n_gpu > 1 and self.parallel_mode == "data_parallel":
                normalization = sum(
