Skip to content

Commit

Permalink
decompose logging and rename helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
joyce-chen-uni committed Aug 26, 2024
1 parent 04d20fa commit a1fa41c
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
self.loss_window = deque(maxlen=self.window_size)
self.loss_cap = MAX_LOSS_CAP

def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
def _detect_loss_spike(self, train_loss: float, running_loss_avg: float):
# Train loss is an outlier
if train_loss >= running_loss_avg * self.outlier_multiplier:
self.outlier_counter += 1
Expand All @@ -82,7 +82,7 @@ def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
self.outlier_counter = 0
return False

def detect_high_losses(self, current_step: int):
def _detect_high_losses(self, current_step: int):
# Half of the running losses are greater than our "high loss" threshold, after an initial buffer period
if (current_step >= self.window_size * 2) and (
sum(1 for loss in self.loss_window if loss > self.loss_cap) >=
Expand All @@ -93,34 +93,30 @@ def detect_high_losses(self, current_step: int):
)
return True
return False

def handle_loss_spike(
self, logger: Logger, running_loss_avg: float
) -> None:

def _log_metadata(self, logger: Logger, key: str, message: str) -> None:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({
'loss_spike':
f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.',
'loss_window':
list(self.loss_window),
key: message,
'loss_window': list(self.loss_window),
})

def _handle_loss_spike(
self, logger: Logger, running_loss_avg: float
) -> None:
message = f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.'
self._log_metadata(logger, 'loss_spike', message)
if not self.log_only:
raise LossSpikeError(
outlier_multiplier=self.outlier_multiplier,
running_loss_avg=round(running_loss_avg),
outlier_counter=self.outlier_counter,
)

def handle_high_losses(self, logger: Logger) -> None:
for destination in logger.destinations:
if isinstance(destination, MosaicMLLogger):
destination.log_metadata({
'high_loss':
f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.',
'loss_window':
list(self.loss_window),
})
def _handle_high_losses(self, logger: Logger) -> None:
message = f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.'
self._log_metadata(logger, 'high_loss', message)
if not self.log_only:
raise HighLossError(
loss_cap=self.loss_cap,
Expand Down

0 comments on commit a1fa41c

Please sign in to comment.