Commit
decompose spike/high loss handling
joyce-chen-uni committed Aug 26, 2024
1 parent e3c00b3 commit 04d20fa
Showing 2 changed files with 72 additions and 33 deletions.
79 changes: 46 additions & 33 deletions llmfoundry/callbacks/kill_loss_spike_callback.py
@@ -20,6 +20,7 @@
__all__ = ['KillLossSpike']

MIN_WINDOW_SIZE = 100
+MAX_LOSS_CAP = 10


@experimental_class('KillLossSpike')
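Since batch_end below takes max(max(self.loss_window), self.loss_cap), the new constant acts as a floor on the eventual loss cap rather than a ceiling. A quick sketch of the arithmetic, using hypothetical first-window losses that are not part of this diff:

    # Hypothetical first windows, not taken from this commit.
    MAX_LOSS_CAP = 10
    calm_window = [3.1, 2.8, 3.4]
    print(max(max(calm_window), MAX_LOSS_CAP))  # 10: a calm first window cannot lower the cap below 10
    hot_window = [9.0, 12.5, 11.0]
    print(max(max(hot_window), MAX_LOSS_CAP))   # 12.5: a hot first window raises the cap above 10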
@@ -41,7 +42,7 @@ class KillLossSpike(Callback):
            the current window. Default is 2.
        window_size (int): The size of the rolling window used to track recent losses. This is set to 1/20 of the total training batches, with a minimum of 100 steps.
        loss_cap (int): The maximum allowable loss. If the training loss consistently exceeds this value,
-            it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses.
+            it is considered a diverging or unstable run. This is set to the maximum loss from the first window of losses, with a minimum of 10.

    Raises:
        LossSpikeError: If log_only is False and a loss spike or persistently high loss is detected, this error is
@@ -61,7 +62,7 @@ def __init__(
        self.outlier_counter = 0
        self.window_size = MIN_WINDOW_SIZE
        self.loss_window = deque(maxlen=self.window_size)
-        self.loss_cap = float('inf')
+        self.loss_cap = MAX_LOSS_CAP

    def detect_loss_spike(self, train_loss: float, running_loss_avg: float):
        # Train loss is an outlier
@@ -93,17 +94,54 @@ def detect_high_losses(self, current_step: int):
            return True
        return False

+    def handle_loss_spike(
+        self, logger: Logger, running_loss_avg: float
+    ) -> None:
+        for destination in logger.destinations:
+            if isinstance(destination, MosaicMLLogger):
+                destination.log_metadata({
+                    'loss_spike':
+                        f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.',
+                    'loss_window':
+                        list(self.loss_window),
+                })
+        if not self.log_only:
+            raise LossSpikeError(
+                outlier_multiplier=self.outlier_multiplier,
+                running_loss_avg=round(running_loss_avg),
+                outlier_counter=self.outlier_counter,
+            )
+
+    def handle_high_losses(self, logger: Logger) -> None:
+        for destination in logger.destinations:
+            if isinstance(destination, MosaicMLLogger):
+                destination.log_metadata({
+                    'high_loss':
+                        f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.',
+                    'loss_window':
+                        list(self.loss_window),
+                })
+        if not self.log_only:
+            raise HighLossError(
+                loss_cap=self.loss_cap,
+                window_size=self.window_size,
+            )
+
    def fit_start(self, state: State, logger: Logger) -> None:
        # Set the window size to a fraction of the total number of training batches for the run, minimum 100 batches.
        if state.max_duration is not None:
            if state.max_duration.unit == TimeUnit.EPOCH and state.dataloader_len is not None:
                self.window_size = max(
-                    MIN_WINDOW_SIZE,
-                    round(float(state.dataloader_len * state.max_duration.value / 20)),
+                    self.window_size,
+                    round(
+                        float(
+                            state.dataloader_len * state.max_duration.value / 20
+                        )
+                    ),
                )
            elif state.max_duration.unit == TimeUnit.BATCH or state.max_duration.unit == TimeUnit.TOKEN:
                self.window_size = max(
-                    MIN_WINDOW_SIZE,
+                    self.window_size,
                    round(float(state.max_duration.value / 20)),
                )
        self.loss_window = deque(maxlen=self.window_size)
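For intuition, the window sizing above works out as follows for a few hypothetical batch-based run lengths; max(self.window_size, ...) behaves like max(MIN_WINDOW_SIZE, ...) here because window_size is initialized to MIN_WINDOW_SIZE in __init__:

    # Hypothetical run lengths, not taken from this commit.
    MIN_WINDOW_SIZE = 100
    for total_batches in (500, 2_000, 10_000):
        window = max(MIN_WINDOW_SIZE, round(total_batches / 20))
        print(total_batches, '->', window)  # 500 -> 100, 2000 -> 100, 10000 -> 500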
@@ -129,40 +167,15 @@ def batch_end(self, state: State, logger: Logger) -> None:

        # Set the loss cap to the maximum loss from the first loss window
        if current_step == self.window_size:
-            self.loss_cap = max(self.loss_window)
+            self.loss_cap = max(max(self.loss_window), self.loss_cap)

        running_loss_avg = float(np.mean(self.loss_window))
        log.info(f'Running loss average: {running_loss_avg}')

        if self.detect_loss_spike(train_loss, running_loss_avg):
-            for destination in logger.destinations:
-                if isinstance(destination, MosaicMLLogger):
-                    destination.log_metadata({
-                        'loss_spike':
-                            f'Training loss spike detected for {self.outlier_counter} consecutive steps. Consider stopping this run and resubmitting with a lower learning rate.',
-                        'loss_window':
-                            list(self.loss_window),
-                    })
-            if not self.log_only:
-                raise LossSpikeError(
-                    outlier_multiplier=self.outlier_multiplier,
-                    running_loss_avg=round(running_loss_avg),
-                    outlier_counter=self.outlier_counter,
-                )
+            self.handle_loss_spike(logger, running_loss_avg)

        elif self.detect_high_losses(current_step):
-            for destination in logger.destinations:
-                if isinstance(destination, MosaicMLLogger):
-                    destination.log_metadata({
-                        'high_loss':
-                            f'Persistently high (>{self.loss_cap}) training losses detected. Consider stopping this run and resubmitting with a lower learning rate.',
-                        'loss_window':
-                            list(self.loss_window),
-                    })
-            if not self.log_only:
-                raise HighLossError(
-                    loss_cap=self.loss_cap,
-                    window_size=self.window_size,
-                )
+            self.handle_high_losses(logger)

        self.loss_window.append(train_loss)
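The bodies of detect_loss_spike and detect_high_losses fall outside this diff. Below is a minimal sketch of detection logic consistent with the call sites and error arguments above; the consecutive-step threshold and the half-window rule are assumptions, not taken from this commit:

    # Sketch only: reconstructed from call sites, not from this diff.
    def detect_loss_spike(self, train_loss: float, running_loss_avg: float) -> bool:
        # An outlier step is one whose loss exceeds the running average by the multiplier.
        if train_loss >= running_loss_avg * self.outlier_multiplier:
            self.outlier_counter += 1
            # Assumed: flag a spike only after more than 4 consecutive outlier steps.
            return self.outlier_counter > 4
        self.outlier_counter = 0
        return False

    def detect_high_losses(self, current_step: int) -> bool:
        # Assumed: only check once a full loss window has accumulated.
        if current_step < self.window_size:
            return False
        high_count = sum(1 for loss in self.loss_window if loss > self.loss_cap)
        # Assumed: flag when at least half of the recent losses exceed the cap.
        return high_count >= self.window_size // 2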
26 changes: 26 additions & 0 deletions tests/callbacks/test_kill_loss_spike_callback.py
@@ -68,6 +68,32 @@ def test_no_error_raised_with_log_only_true(self, _):
        except Exception as e:
            self.fail(f'batch_end raised an exception {e} with log_only=True')

+    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
+    def test_error_raised_with_log_only_false(self, _):
+        build_tiny_mpt = MagicMock()
+        build_tiny_mpt.return_value = MagicMock()
+        state = State(
+            model=build_tiny_mpt(loss_fn='torch_crossentropy'),
+            rank_zero_seed=0,
+            run_name='test_state',
+            device=DeviceCPU(),
+        )
+        state.loss = torch.tensor(4)
+        state.timestamp = Timestamp(batch=21)
+        logger = Logger(state, destinations=[MosaicMLLogger()])
+
+        # Loss spike detection should trigger
+        self.callback.outlier_counter = 4
+        self.callback.loss_window = deque([2] * 10, maxlen=10)
+        self.callback.log_only = False
+
+        result = self.callback.detect_loss_spike(state.loss.item(), 2)
+        self.assertTrue(result)
+
+        # batch_end should raise an error due to log_only=False
+        with self.assertRaises(LossSpikeError):
+            self.callback.batch_end(state, logger)
+
    @patch('llmfoundry.callbacks.kill_loss_spike_callback.log')
    def test_detect_high_losses_no_high_losses(self, _):
        self.callback.loss_window = deque([2] * 10, maxlen=10)
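For context, a hypothetical way to wire the callback into a Composer Trainer; log_only comes from the docstring above, while the model and dataloader are user-supplied placeholders rather than anything shown in this diff:

    from composer import Trainer

    from llmfoundry.callbacks.kill_loss_spike_callback import KillLossSpike

    # log_only=False makes the callback raise LossSpikeError / HighLossError
    # instead of only logging metadata to the MosaicML logger.
    callback = KillLossSpike(log_only=False)

    trainer = Trainer(
        model=model,                        # placeholder: your ComposerModel
        train_dataloader=train_dataloader,  # placeholder: your training DataLoader
        max_duration='10000ba',             # batch-based duration, as handled in fit_start
        callbacks=[callback],
    )
    trainer.fit()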
