Skip to content

Commit

Permalink
Log exception on inactivity callback (#1194)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjanezhang authored May 10, 2024
1 parent 983234d commit 994209c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 4 deletions.
12 changes: 8 additions & 4 deletions llmfoundry/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,17 @@
from llmfoundry.callbacks.eval_output_logging_callback import EvalOutputLogging
from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \
MegaBlocksMoE_TokPerExpert
from llmfoundry.callbacks.monolithic_ckpt_callback import \
MonolithicCheckpointSaver
from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import (
MegaBlocksMoE_TokPerExpert,
)
from llmfoundry.callbacks.monolithic_ckpt_callback import (
MonolithicCheckpointSaver,
)
from llmfoundry.callbacks.resumption_callbacks import (
GlobalLRScaling,
LayerFreezing,
)
from llmfoundry.callbacks.run_timeout_callback import RunTimeoutCallback
from llmfoundry.callbacks.scheduled_gc_callback import ScheduledGarbageCollector
from llmfoundry.registry import callbacks, callbacks_with_config

Expand All @@ -47,6 +50,7 @@
callbacks.register('oom_observer', func=OOMObserver)
callbacks.register('eval_output_logging', func=EvalOutputLogging)
callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)
callbacks.register('run_timeout', func=RunTimeoutCallback)

callbacks_with_config.register('async_eval', func=AsyncEval)
callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
Expand Down
58 changes: 58 additions & 0 deletions llmfoundry/callbacks/run_timeout_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import logging
import os
import signal
import threading
from typing import Optional

from composer import Callback, Logger, State
from composer.loggers import MosaicMLLogger

from llmfoundry.utils.exceptions import RunTimeoutError

log = logging.getLogger(__name__)


def _timeout(timeout: int, mosaicml_logger: Optional[MosaicMLLogger] = None):
    """Report the inactivity timeout, then interrupt the current process.

    Logs an error, optionally reports a ``RunTimeoutError`` through the
    MosaicML logger, and finally sends ``SIGINT`` to this process.

    Args:
        timeout: Number of seconds that elapsed; used in the log message and
            in the ``RunTimeoutError`` that is reported.
        mosaicml_logger: When given, receives a ``RunTimeoutError`` through
            its ``log_exception`` method before the process is signalled.
    """
    log.error(f'Timeout after {timeout} seconds of inactivity after fit_end.',)
    try:
        if mosaicml_logger is not None:
            mosaicml_logger.log_exception(RunTimeoutError(timeout=timeout))
    finally:
        # Reporting is best-effort: the interrupt must still be delivered when
        # the logger raises, otherwise the timeout would silently do nothing.
        os.kill(os.getpid(), signal.SIGINT)


class RunTimeoutCallback(Callback):
    """Callback that interrupts the process some time after ``fit_end``.

    When ``fit_end`` fires, a daemon ``threading.Timer`` is armed. If it is
    not cancelled within ``timeout`` seconds, it invokes the module-level
    ``_timeout`` function with the timeout and the discovered MosaicML
    logger (if any).

    Args:
        timeout: Delay in seconds between ``fit_end`` and the timer firing.
            Defaults to 1800.
    """

    def __init__(
        self,
        timeout: int = 1800,
    ):
        self.timeout = timeout
        # Populated in ``init`` if a MosaicMLLogger is among the callbacks.
        self.mosaicml_logger: Optional[MosaicMLLogger] = None
        # The currently armed timer, or None when nothing is pending.
        self.timer: Optional[threading.Timer] = None

    def init(self, state: State, logger: Logger):
        """Remember the MosaicMLLogger registered on ``state``, if present."""
        found = [
            candidate for candidate in state.callbacks
            if isinstance(candidate, MosaicMLLogger)
        ]
        if found:
            # With several matches the last one wins.
            self.mosaicml_logger = found[-1]

    def _reset(self):
        """Cancel and forget the pending timer; no-op when none is armed."""
        if self.timer is None:
            return
        self.timer.cancel()
        self.timer = None

    def _timeout(self):
        """Arm a fresh daemon timer, replacing any timer already pending."""
        self._reset()
        countdown = threading.Timer(
            self.timeout,
            _timeout,
            args=[self.timeout, self.mosaicml_logger],
        )
        self.timer = countdown
        # Daemon thread: a pending timer does not keep the interpreter alive.
        countdown.daemon = True
        countdown.start()

    def fit_end(self, state: State, logger: Logger):
        """Start the inactivity countdown once fitting has finished."""
        del state, logger  # unused
        self._timeout()
9 changes: 9 additions & 0 deletions llmfoundry/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,12 @@ def __init__(self, dataset_name: str, split: str) -> None:
self.split = split
message = f'Your dataset (name={dataset_name}, split={split}) is misconfigured. Please check your dataset format and make sure you can load your dataset locally.'
super().__init__(message)


class RunTimeoutError(RuntimeError):
    """Error thrown when a run times out.

    Args:
        timeout: Number of seconds after which the run timed out. Stored on
            ``self.timeout`` and embedded in the error message.
    """

    def __init__(self, timeout: int) -> None:
        self.timeout = timeout
        message = f'Run timed out after {timeout} seconds.'
        super().__init__(message)

    def __reduce__(self):
        """Rebuild from ``timeout`` so pickle/copy round trips are faithful."""
        # The default exception reduction replays ``self.args`` (here the
        # already-formatted message) into ``__init__``, which would treat the
        # message as ``timeout`` and format it a second time.
        return (type(self), (self.timeout,))

0 comments on commit 994209c

Please sign in to comment.