diff --git a/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst b/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst index 31c532b70af..1e7846d6295 100644 --- a/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst +++ b/docs/model-dev-guide/hyperparameter/search-methods/hp-adaptive-asha.rst @@ -25,7 +25,7 @@ Resource budget: `). Note that the searcher will expect this metric to appear in validation metrics reported by the model. This quantity is domain-specific and should roughly reflect the number of minibatches the model must be trained on for it to converge on the - data set. For users who would like to determine this number experimentally, train a model with + dataset. For users who would like to determine this number experimentally, train a model with reasonable hyperparameters using the ``single`` search method. - ``max_trials``: This indicates the total number of hyperparameter settings that will be evaluated diff --git a/docs/reference/experiment-config-reference.rst b/docs/reference/experiment-config-reference.rst index 716f027ef11..fbafc3e9dbc 100644 --- a/docs/reference/experiment-config-reference.rst +++ b/docs/reference/experiment-config-reference.rst @@ -335,8 +335,8 @@ Optional. Specifies the minimum frequency at which validation should be run for epochs: 2 - :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and - :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, - :ref:`records_per_epoch ` must be specified. + :class:`~determined.keras.TFKerasTrial`: If this is in the unit of epochs, ``records_per_epoch`` + must be specified. .. _experiment-config-perform-initial-validation: @@ -377,7 +377,7 @@ Optional. Specifies the minimum frequency for running checkpointing for each tri - :class:`~determined.pytorch.deepspeed.DeepSpeedTrial` and :class:`~determined.keras.TFKerasTrial`: If the unit is in epochs, you must also specify - :ref:`records_per_epoch `. + ``records_per_epoch``. ``checkpoint_policy`` ===================== diff --git a/docs/reference/training/_index.rst b/docs/reference/training/_index.rst index 84ace82203d..aa221c7b906 100644 --- a/docs/reference/training/_index.rst +++ b/docs/reference/training/_index.rst @@ -15,6 +15,7 @@ - :ref:`det.pytorch.samplers ` - :ref:`det.pytorch.deepspeed ` - :ref:`det.keras ` +- :ref:`det.transformers ` ******************************* Experiment Configuration File diff --git a/docs/reference/training/api-transformers-reference.rst b/docs/reference/training/api-transformers-reference.rst new file mode 100644 index 00000000000..8ded8cf68c6 --- /dev/null +++ b/docs/reference/training/api-transformers-reference.rst @@ -0,0 +1,11 @@ +.. _transformers-reference: + +#################################### + ``det.transformers`` API Reference +#################################### + +***************************************** + ``determined.transformers.DetCallback`` +***************************************** + +.. 
autoclass:: determined.transformers.DetCallback diff --git a/e2e_tests/tests/config.py b/e2e_tests/tests/config.py index bedac9943e0..cfac33d5fd9 100644 --- a/e2e_tests/tests/config.py +++ b/e2e_tests/tests/config.py @@ -1,6 +1,6 @@ import os import pathlib -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from determined.common import api, util diff --git a/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml b/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml index 8e4345eebc8..0c2e8efc297 100644 --- a/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml +++ b/e2e_tests/tests/fixtures/hpc/embedded-single-quote.yaml @@ -5,7 +5,5 @@ data: searcher: name: single metric: error - max_length: - batches: 1000 max_restarts: 0 entrypoint: python3 data_validator.py diff --git a/examples/hf_trainer_api/hf_image_classification/adaptive.yaml b/examples/hf_trainer_api/hf_image_classification/adaptive.yaml index 4407ebf100d..59850ae5150 100644 --- a/examples/hf_trainer_api/hf_image_classification/adaptive.yaml +++ b/examples/hf_trainer_api/hf_image_classification/adaptive.yaml @@ -9,8 +9,8 @@ resources: slots_per_trial: 2 searcher: name: adaptive_asha - max_length: - batches: 100 + time_metric: batches + max_time: 100 max_trials: 64 max_rungs: 4 divisor: 4 diff --git a/examples/hf_trainer_api/hf_image_classification/const.yaml b/examples/hf_trainer_api/hf_image_classification/const.yaml index 2a3b95535fa..fcf6a9d3844 100644 --- a/examples/hf_trainer_api/hf_image_classification/const.yaml +++ b/examples/hf_trainer_api/hf_image_classification/const.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 1 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml b/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml index 49425d854af..2fb6a69f5a4 100644 --- a/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml +++ b/examples/hf_trainer_api/hf_image_classification/const_epochs.yaml @@ -10,8 +10,6 @@ resources: records_per_epoch: 1000 searcher: name: single - max_length: - epochs: 5 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml b/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml index f60ec4d218d..698d68f8bba 100644 --- a/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml +++ b/examples/hf_trainer_api/hf_image_classification/deepspeed.yaml @@ -11,8 +11,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: deepspeed_config: ds_configs/ds_config_stage_1.json diff --git a/examples/hf_trainer_api/hf_image_classification/distributed.yaml b/examples/hf_trainer_api/hf_image_classification/distributed.yaml index a9ea4ca154b..fe74f1ec1b7 100644 --- a/examples/hf_trainer_api/hf_image_classification/distributed.yaml +++ b/examples/hf_trainer_api/hf_image_classification/distributed.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml b/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml index 7799582879a..946aa15f9f1 100644 --- a/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/adaptive.yaml @@ -9,8 +9,8 @@ resources: 
slots_per_trial: 2 searcher: name: adaptive_asha - max_length: - batches: 100 + time_metric: batches + max_time: 100 max_trials: 64 max_rungs: 4 divisor: 4 diff --git a/examples/hf_trainer_api/hf_language_modeling/const.yaml b/examples/hf_trainer_api/hf_language_modeling/const.yaml index 294504aed07..e340834457c 100644 --- a/examples/hf_trainer_api/hf_language_modeling/const.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/const.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 1 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml b/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml index 75b44dc5f97..b544db63620 100644 --- a/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/const_epochs.yaml @@ -10,8 +10,6 @@ resources: records_per_epoch: 1000 searcher: name: single - max_length: - epochs: 5 metric: eval_loss hyperparameters: training_arguments: diff --git a/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml b/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml index 66ff58889fc..8facb3c47ac 100644 --- a/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/deepspeed.yaml @@ -11,8 +11,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: deepspeed_config: ds_configs/ds_config_stage_1.json diff --git a/examples/hf_trainer_api/hf_language_modeling/distributed.yaml b/examples/hf_trainer_api/hf_language_modeling/distributed.yaml index c305d98f490..08b62b79788 100644 --- a/examples/hf_trainer_api/hf_language_modeling/distributed.yaml +++ b/examples/hf_trainer_api/hf_language_modeling/distributed.yaml @@ -9,8 +9,6 @@ resources: slots_per_trial: 2 searcher: name: single - max_length: - batches: 100 metric: eval_loss hyperparameters: training_arguments: diff --git a/harness/determined/transformers/_hf_callback.py b/harness/determined/transformers/_hf_callback.py index b9840db5f55..e4ccc3a819b 100644 --- a/harness/determined/transformers/_hf_callback.py +++ b/harness/determined/transformers/_hf_callback.py @@ -8,10 +8,28 @@ import determined as det -logger = logging.getLogger("determined.transformers") +logger = logging.getLogger("det.transformers") class DetCallback(transformers.TrainerCallback): # type: ignore + """ + ``DetCallback`` integrates a training loop built around ``transformers.Trainer`` with the + Determined cluster. It reports metrics, uploads checkpoints, and handles preemption signals. + It also automatically restores training from the latest checkpoint after pauses or crashes. + + Simply include ``DetCallback`` as in the list of ``callbacks`` that you pass to your + ``Trainer``. + + Args: + core_context: the result of a ``det.core.init()`` call. + args: ``TrainingArgs`` from a ``transformers.HfArgumentParser``, the same ``args`` to be + passed to the ``Trainer``. + filter_metrics: a list of metric names to report to Determined. Default: ``None`` (all + metrics are reported). + user_data: an optional dict of metadata to be stored in every checkpoint. + Default: ``None``. 
+ """ + def __init__( self, core_context: det.core.Context, @@ -20,32 +38,125 @@ def __init__( user_data: Optional[Dict[str, Any]] = None, ) -> None: super().__init__() - self.core_context = core_context - self.filter_metrics = filter_metrics self.user_data = user_data + + self.last_train_metrics = -1 + self.last_eval_metrics = -1 + self.last_save = -1 + self.last_progress = 0 + + info = det.get_cluster_info() + if not info: + raise RuntimeError("det.transformers.DetCallback must be run on a Determined cluster") + self.info = info + self.load_last_checkpoint(args) - self.last_metrics: Dict[str, float] = {"train_step": -1, "eval_step": -1} - self.searcher_ops = self.core_context.searcher.operations() - self.current_op = next(self.searcher_ops) - self.updating_searcher = False - - cluster_info = det.get_cluster_info() - assert ( - cluster_info - ), "Could not find `cluster_info`, the HF Callback must be run on a Determined Cluster" - searcher_config = cluster_info.trial._config["searcher"] - self.searcher_metric = searcher_config["metric"] - # Custom searchers have a different config structure which need to be handled differently - if searcher_config["name"] == "custom": - self.searcher_unit = "batches" - self.searcher_max_length = self.current_op.length + self.searcher_metric = None + self.time_metric = None + if self.info.task_type == "TRIAL": + searcher_config = self.info.trial._config["searcher"] + self._check_searcher_config(searcher_config, args) + self.searcher_metric = searcher_config["metric"] + self.time_metric = searcher_config.get("time_metric") + # Don't allow filtering of the searcher or time_metric metrics. + if self.filter_metrics: + self.filter_metrics.append(self.searcher_metric) + if self.time_metric: + self.filter_metrics.append(self.time_metric) + + # Undocumented workarounds in case forcing the checkpoint and validations at the end of + # non-preempted training is a bad idea somehow. + self._force_final_save = True + self._force_final_evaluate = True + + def load_last_checkpoint(self, args: transformers.TrainingArguments) -> None: + latest_checkpoint = self.info.latest_checkpoint + if latest_checkpoint is None: + return + if args.overwrite_output_dir is True: + logger.info( + "Skipping downloading last checkpoint from Determined due " + "to overwrite_output_dir=True." + ) + return + + # To resume DeepSpeed, each node requires ALL sharded model/optimizer states, + # so we can skip using selector and just download all files. + self.core_context.checkpoint.download(latest_checkpoint, args.output_dir) + + checkpoint_path = trainer_utils.get_last_checkpoint(args.output_dir) + args.resume_from_checkpoint = checkpoint_path + + logger.info(f"Latest checkpoint downloaded to {checkpoint_path}.") + + def _check_searcher_config( + self, cfg: Dict[str, Any], args: transformers.TrainingArguments + ) -> None: + if args.max_steps > -1: + args_unit = "batches" + args_len = args.max_steps + len_arg = "--max_steps" else: - self.searcher_unit = list(searcher_config["max_length"].keys())[0] - self.searcher_max_length = list(searcher_config["max_length"].values())[0] - self._check_searcher_compatibility(args) + args_unit = "epochs" + args_len = args.num_train_epochs + len_arg = "--num_train_epochs" + + if isinstance(cfg.get("max_length"), int): + # Legacy searcher config (unitless). Has never been supported, actually. + raise ValueError( + "HF trainer no longer respects the deprecated searcher.max_length " + "field. 
Please remove it and rely " + f"on {len_arg} instead to avoid ambiguous training specifications." + ) + elif isinstance(cfg.get("max_length"), dict): + # Legacy searcher config; max_length must match provided args. + search_unit, search_len = next(iter(cfg["max_length"].items())) + if (search_unit, search_len) != (args_unit, args_len): + raise ValueError( + "HF trainer training length does not match the configured searcher.max_length " + f"({args_unit}={args_len} != {search_unit}={search_len}). The " + "searcher.max_length field is deprecated; please remove it and avoid " + "ambiguous training specifications." + ) + elif cfg["name"] in ["adaptive_asha", "async_halving"]: + # ASHA search: check that time_metric and max_time are sane. + search_unit = cfg["time_metric"] + search_len = cfg["max_time"] + if search_unit not in ("batches", "epochs"): + # A custom time metric; _check_eval_metrics() verifies that it is reported. + pass + elif (search_unit, search_len) != (args_unit, args_len): + name = cfg["name"] + raise ValueError( + "HF trainer training length does not match the max_time configured for the " + f"{name} searcher ({args_unit}={args_len} != {search_unit}={search_len}). " + f"Please update either the searcher.max_time config field or {len_arg} " + "so that they match." + ) + + def _check_eval_metrics(self, metrics: Dict[str, Any]) -> None: + search_ok = self.searcher_metric is None or self.searcher_metric in metrics + time_ok = self.time_metric is None or self.time_metric in metrics + if not search_ok and not time_ok: + raise ValueError( + f"Searcher metric '{self.searcher_metric}' set by searcher.metric config field " + f"and time metric '{self.time_metric}' from searcher.time_metric config field are " + "both missing; you must emit those metrics for the hyperparameter search to work." + ) + if not search_ok: + raise ValueError( + f"Searcher metric '{self.searcher_metric}' set by searcher.metric config field " + "is missing; you must emit that metric for features like hyperparameter search, " + "checkpoint garbage collection, and selecting the best checkpoint to work." + ) + if not time_ok: + raise ValueError( + f"Time metric '{self.time_metric}' set by searcher.time_metric config field is " + "missing; you must emit that metric for the hyperparameter search to work." + ) def on_log( self, @@ -60,52 +171,56 @@ def on_log( return metrics, metric_type = self._get_metrics(logs) logger.debug(f"on_log metrics, global_step {state.global_step}", metrics) + metrics["batches"] = metrics.get("batches", state.global_step) + metrics["epochs"] = metrics.get("epochs", state.epoch) if metric_type == TRAIN: # Prevents reporting metrics for the same step twice. This happens after # training is completed and average training metrics are reported with # the same step as the in-progress training metrics. - if self.last_metrics["train_step"] != state.global_step: + if self.last_train_metrics != state.global_step: + self.last_train_metrics = state.global_step if state.is_world_process_zero: - self.core_context.train.report_training_metrics( - steps_completed=state.global_step, metrics=metrics + # Note: state.global_step represents steps_completed, not step index + self.core_context.train.report_metrics( + group="training", steps_completed=state.global_step, metrics=metrics ) - metrics["train_step"] = state.global_step elif metric_type == EVAL: # Prevents reporting metrics for the same step twice. 
This happens when # after-training evaluation is completed, and it is reported with the same # step as the last during-training evaluation. - if self.last_metrics["eval_step"] != state.global_step: + if self.last_eval_metrics != state.global_step: + self.last_eval_metrics = state.global_step if state.is_world_process_zero: - self.core_context.train.report_validation_metrics( - steps_completed=state.global_step, metrics=metrics + self._check_eval_metrics(metrics) + # Note: state.global_step represents steps_completed, not step index + self.core_context.train.report_metrics( + group="validation", steps_completed=state.global_step, metrics=metrics ) - metrics["eval_step"] = state.global_step else: logger.warning(f"Metrics not reported: metric type = {metric_type}.") - self.last_metrics.update(metrics) - - # Update searcher state after collecting the metrics. - if self.updating_searcher is True: - self._update_searcher(state, control) - - # If searcher is NOT being updated and preemption signal is received - # (e.g., by pausing experiment in the WebUI), notify Trainer (via TrainerControl) - # to save the checkpoint and stop training. - if self.updating_searcher is False and self.core_context.preempt.should_preempt(): + # If we've been preempted, save a checkpoint and shut down training. + if self.core_context.preempt.should_preempt(): control.should_training_stop = True - control.should_save = True + # Don't set control.should_save now, or it can trigger multiple saves, if we trigger + # in a training on_log and arrive here again in an evaluate on_log. We would not cause + # that to happen, but other callbacks could, such as if it were just naturally time for + # an evaluation. So just let the save-at-end logic handle it. def _get_metrics(self, logs: Dict[str, Any]) -> Tuple[Dict[str, Any], str]: - metrics = logs metric_type = get_metric_type(logs) - if self.filter_metrics: - metrics = {} - for k, v in logs.items(): - if any(m in k for m in self.filter_metrics) is True: - metrics[k] = v - + if not self.filter_metrics: + metrics = logs + else: + metrics = {k: v for k, v in logs.items() if any(m in k for m in self.filter_metrics)} + # Remove the default rounded 'epoch' metric. + metrics.pop("epoch", None) + # Also remove speed metrics. + speed_suffixes = ["_runtime", "_per_second", "_compilation_time"] + speed_metrics = [m for m in metrics if any(m.endswith(s) for s in speed_suffixes)] + for m in speed_metrics: + metrics.pop(m, None) return metrics, metric_type def on_save( @@ -115,25 +230,24 @@ def on_save( control: transformers.TrainerControl, **kwargs: Any, ) -> None: - info = det.get_cluster_info() - assert info - + self.last_save = state.global_step # local_path is where HF Trainer saves model and tokenizer in a given step. 
local_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}") if state.is_world_process_zero: if self.user_data is not None: self._on_save_user_data(local_path) - det_checkpoint_metadata = { + metadata = { "steps_completed": state.global_step, - "trial_id": info.trial.trial_id, } + if self.info.task_type == "TRIAL": + metadata["trial_id"] = self.info.trial.trial_id def selector(x: str) -> bool: return x.startswith((f"checkpoint-{state.global_step}/", "runs/")) self.core_context.checkpoint.upload( - args.output_dir, metadata=det_checkpoint_metadata, shard=True, selector=selector + args.output_dir, metadata=metadata, shard=True, selector=selector ) def _on_save_user_data(self, save_path: str) -> None: @@ -145,28 +259,6 @@ def _on_save_user_data(self, save_path: str) -> None: with open(os.path.join(save_path, "my_data.json"), "w") as f: json.dump(self.user_data, f) - def load_last_checkpoint(self, args: transformers.TrainingArguments) -> None: - info = det.get_cluster_info() - assert info - - latest_checkpoint = info.latest_checkpoint - if latest_checkpoint is not None: - if args.overwrite_output_dir is True: - logger.info( - "Skip downloading last checkpoint from Determined due " - "to overwrite_output_dir=True." - ) - return - - # To resume DeepSpeed, each node requires ALL sharded model/optimizer states, - # so we can skip using selector and just download all files. - self.core_context.checkpoint.download(latest_checkpoint, args.output_dir) - - checkpoint_path = trainer_utils.get_last_checkpoint(args.output_dir) - args.resume_from_checkpoint = checkpoint_path - - logger.info(f"Latest checkpoint downloaded to {checkpoint_path}.") - def on_step_end( self, args: transformers.TrainingArguments, @@ -174,17 +266,14 @@ def on_step_end( control: transformers.TrainerControl, **kwargs: Any, ) -> None: - # state.epoch is not None only during training. - if state.epoch and self.searcher_unit == "batches": - if state.is_world_process_zero: - self.current_op.report_progress(state.global_step) - - if state.global_step >= self.current_op.length: - logger.info( - f"Max length of {self.current_op.length} steps reached for current " - f"searcher operation. Updating searcher." - ) - self._update_searcher(state, control) + if state.is_world_process_zero and args.max_steps > -1: + # There needs to be at least 1% increase in progress to report progress (maximum 100 + # report_progress API calls in per trial). + progress = state.global_step / args.max_steps + percent = int(progress * 100) + if percent > self.last_progress: + self.last_progress = percent + self.core_context.train.report_progress(progress) def on_epoch_end( self, @@ -193,95 +282,30 @@ def on_epoch_end( control: transformers.TrainerControl, **kwargs: Any, ) -> None: - # state.epoch is not None only during training. - if state.epoch and self.searcher_unit == "epochs": - if state.is_world_process_zero: - self.current_op.report_progress(state.epoch) - - if state.epoch >= self.current_op.length: - logger.info( - f"Max length of {state.epoch} epochs reached for current " - f"searcher operation. Updating searcher." - ) - self._update_searcher(state, control) - - def _update_searcher( - self, state: transformers.TrainerState, control: transformers.TrainerControl - ) -> None: - if self._metrics_reported(state.global_step) is False: - self._wait_for_metrics(control) - return - - if state.is_world_process_zero: - if self.last_metrics is None: - logger.warning( - "No training or evaluation metrics has been recorded. 
Please " - "check your settings for training metrics " - "(--logging_strategy and --logging_steps) or " - "evaluation metrics (--evaluation_strategy and --eval_steps). " - "Reporting trainer_state.best_metric to the searcher." - ) - searcher_metric = state.best_metric - elif self.searcher_metric not in self.last_metrics: - logger.warning( - f"Searcher metric {self.searcher_metric} from the yaml config file does " - "not match any of the recorded metrics " - f"in {self.last_metrics}. " - "Reporting trainer_state.best_metric to the searcher." - ) - searcher_metric = state.best_metric - else: - searcher_metric = self.last_metrics[self.searcher_metric] - - logger.info(f"Metric reported to searcher: {searcher_metric}") - self.current_op.report_completed(searcher_metric) - - self.updating_searcher = False + # Decide if we're about to shut down training. + is_end = False + if control.should_training_stop: + is_end = True + elif args.max_steps > -1: + is_end = state.global_step >= args.max_steps + else: + is_end = state.epoch >= args.num_train_epochs - try: - self.current_op = next(self.searcher_ops) - except StopIteration: - control.should_training_stop = True + # If training is ending, this is our last chance to ask for a eval and/or save. + if is_end: + # Avoid stale evaluate-at-end. + if state.global_step > self.last_eval_metrics: + # Also avoid evaluate-at-end if we have been preempted. + if self._force_final_evaluate and not self.core_context.preempt.should_preempt(): + control.should_evaluate = True + # Avoid stale save-at-end. + if state.global_step > self.last_save: + # You can't disable save-after-preemption. + if self._force_final_save or self.core_context.preempt.should_preempt(): + control.should_save = True - def _metrics_reported(self, step: int) -> bool: - return self.last_metrics["eval_step"] == step and self.last_metrics["train_step"] == step - - def _wait_for_metrics(self, control: transformers.TrainerControl) -> None: - # Notify Trainer (via transformers.TrainerControl) to: - # (1) log current training metrics, - # (2) evaluate the model and log evaluation metrics, - # (3) save the checkpoint. - # updating_searcher is as an internal flag that indicates we are - # in the process of updating the searcher with the current metrics. - control.should_log = True - control.should_evaluate = True - control.should_save = True - self.updating_searcher = True - - def _check_searcher_compatibility(self, args: transformers.TrainingArguments) -> None: - if self.searcher_unit == "batches": - if args.max_steps == -1: - self._raise_config_mismatch("epochs", args.num_train_epochs) - elif args.max_steps != self.searcher_max_length: - self._raise_config_mismatch("batches", args.max_steps) - elif self.searcher_unit == "epochs": - if args.max_steps != -1: - self._raise_config_mismatch("batches", args.max_steps) - elif args.num_train_epochs != self.searcher_max_length: - self._raise_config_mismatch("epochs", args.num_train_epochs) - - def _raise_config_mismatch( - self, - trainer_units: str, - trainer_len: float, - ) -> None: - raise ValueError( - f"HF trainer units {trainer_units}={trainer_len} MUST match searcher config " - f"{self.searcher_unit}={self.searcher_max_length}. " - f"Modify either --num_train_epochs for the training script or " - f"searcher.max_length.epochs in the experiment config so they are the same value " - f"(--max_steps and searcher.max_length.batches if using batches)." 
- ) + if state.is_world_process_zero and args.max_steps == -1: + self.core_context.train.report_progress(state.epoch / args.num_train_epochs) EVAL = "eval_" @@ -290,13 +314,10 @@ def _raise_config_mismatch( def get_metric_type(d: Dict[str, Any]) -> str: - for k, _ in d.items(): - if k.startswith(EVAL): - return EVAL - elif k.startswith(TEST): - return TEST - else: - return TRAIN + if any(k.startswith(EVAL) for k in d): + return EVAL + if any(k.startswith(TEST) for k in d): + return TEST return TRAIN diff --git a/harness/tests/experiment/transformers/__init__.py b/harness/tests/experiment/transformers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/harness/tests/experiment/transformers/test_callback.py b/harness/tests/experiment/transformers/test_callback.py new file mode 100644 index 00000000000..e1df54a1b65 --- /dev/null +++ b/harness/tests/experiment/transformers/test_callback.py @@ -0,0 +1,495 @@ +import pathlib +import re +from typing import Any, Callable, Dict, Optional, Tuple +from unittest import mock + +import numpy as np +import torch +import transformers + +import determined as det +import determined.transformers +from determined import core +from determined.common import storage +from tests.experiment import utils +from tests.launch import test_util + + +def mock_core_context( + path: str, events: utils.Events, distributed: Optional[core.DistributedContext] = None +) -> Tuple[core.Context, Callable[[], None]]: + """ + Returns a core_context and a set_preempt() callable. + + The core_context is partially mocked to support triggering preemption from test code and to log + all reports to the provided Events object. + """ + # Set up a functional DistributedContext. + distributed = distributed or core.DummyDistributedContext() + # Set up a functional CheckpointContext. + storage_manager = storage.SharedFSStorageManager(path) + + class DummyCheckpointContext(core.DummyCheckpointContext): + def _report_checkpoint( + self, + storage_id: str, + resources: Optional[Dict[str, int]] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: + events.append(("report_checkpoint", storage_id)) + super()._report_checkpoint(storage_id, resources, metadata) + + checkpoint = DummyCheckpointContext(distributed, storage_manager) + + # Mock everything else, logging report-like calls to events. + + def report_metrics(group: str, steps_completed: int, metrics: Any) -> None: + events.append((f"report_metrics:{group}:{steps_completed}", metrics)) + + def report_progress(progress: float) -> None: + fourdigits = "%.4f" % progress + events.append((f"report_progress:{fourdigits}", progress)) + + def set_status(status: str) -> None: + events.append((f"set_status:{status}", None)) + + preempted = False + + def should_preempt() -> bool: + nonlocal preempted + return preempted + + core_context = mock.Mock() + core_context.distributed = distributed + core_context.preempt.should_preempt.side_effect = should_preempt + core_context.checkpoint = checkpoint + core_context.train.report_metrics.side_effect = report_metrics + core_context.train.report_progress.side_effect = report_progress + core_context.train.set_status.side_effect = set_status + + def set_preempt() -> None: + nonlocal preempted + preempted = True + + return core_context, set_preempt + + +class MyOneVarModel(torch.nn.Linear): # type: ignore + """ + Subclass torch.nn.Linear with custom behaviors to be Transformers.Trainer-friendly. 
+ """ + + def __init__(self) -> None: + super().__init__(1, 1, False) + self.weight.data.fill_(0) + self._loss_fn = torch.nn.MSELoss() + + # Signature must match key in dataset's output. + def forward(self, x: torch.Tensor, label_y: torch.Tensor) -> Dict[str, torch.Tensor]: + y = super().forward(x) + loss = self._loss_fn(y, label_y) + # We must return a dict with "loss" as a key. + # (technically a tuple with loss as the first element is also ok) + return {"loss": loss, "pred_y": y} + + +class OnesDataset(torch.utils.data.Dataset): + def __init__(self, dataset_len: int) -> None: + self.dataset_len = dataset_len + + def __len__(self) -> int: + return self.dataset_len + + def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: + # Key name must match model's .forward() signature. + return {"x": torch.Tensor([float(1)]), "label_y": torch.Tensor([float(1)])} + + +def compute_metrics(pred: transformers.EvalPrediction) -> Dict[str, float]: + # Return a mean absolute error as a metric. + return {"mae": np.abs(pred.predictions - pred.label_ids).mean()} + + +class DetCallbackForTesting(det.transformers.DetCallback): + def __init__(self, events: utils.Events, *args: Any, **kwargs: Any) -> None: + self.events = events + super().__init__(*args, **kwargs) + + def on_train_begin( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"on_train_begin:{state.global_step}:{epoch}", None)) + + def on_epoch_begin( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"on_epoch_begin:{state.global_step}:{epoch}", None)) + + def on_epoch_end( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + weight = kwargs["model"].weight.data.item() + self.events.append((f"before_epoch_end:{state.global_step}:{epoch}", weight)) + super().on_epoch_end(args, state, control) + self.events.append((f"after_epoch_end:{state.global_step}:{epoch}", weight)) + + def on_save( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"before_save:{state.global_step}:{epoch}", None)) + super().on_save(args, state, control) + self.events.append((f"after_save:{state.global_step}:{epoch}", None)) + + def on_evaluate( + self, + args: transformers.TrainingArguments, + state: transformers.TrainerState, + control: transformers.TrainerControl, + **kwargs: Any, + ) -> None: + epoch = "%.4f" % state.epoch + self.events.append((f"on_evaluate:{state.global_step}:{epoch}", None)) + + def on_train_end(self, *args: Any, **kwargs: Any) -> None: + self.events.append(("on_train_end", None)) + + +def do_train( + tmp_path: pathlib.Path, + force_final_save: Optional[bool] = None, + force_final_evaluate: Optional[bool] = None, + set_preempt_on_event: Optional[str] = None, + latest_checkpoint: Optional[str] = None, + **kwargs: Any, +) -> utils.Events: + args = transformers.TrainingArguments( + output_dir=str(tmp_path / "trainer"), disable_tqdm=True, **kwargs + ) + + with test_util.set_mock_cluster_info(["0.0.0.0"], 0, 1) as info: + info.trial._config = {"searcher": 
{"name": "single", "metric": "eval_mae"}} + info._latest_checkpoint = latest_checkpoint + + model = MyOneVarModel() + train_dataset = OnesDataset(64) + eval_dataset = OnesDataset(64) + + events = utils.Events() + core_context, set_preempt = mock_core_context(str(tmp_path / "ckpt"), events) + + if set_preempt_on_event: + # Configure a hook for Events that calls set_preempt() when a matching event arrives. + p = re.compile(set_preempt_on_event) + + def hook(summary: str, data: Any) -> None: + if p.search(summary): + set_preempt() + + events.hook = hook + + det_cb = DetCallbackForTesting(events, core_context, args) + if force_final_save is not None: + det_cb._force_final_save = force_final_save + if force_final_evaluate is not None: + det_cb._force_final_evaluate = force_final_evaluate + + t = transformers.Trainer( + model=model, + args=args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + compute_metrics=compute_metrics, + callbacks=[det_cb], + ) + # The call to train must specify the checkpoint. We do set args.resume_from_checkpoint in + # our DetCallback but it isn't automatically respected. + t.train(resume_from_checkpoint=args.resume_from_checkpoint) + + return events + + +def check_hf_metrics(metrics: Dict[str, Any]) -> None: + # We remove the default rounded 'epoch' metric, and the + assert "epoch" not in metrics, metrics + # We remove the speed metrics. + speed_suffixes = ["_runtime", "_per_second", "_compilation_time"] + assert not any(any(m.endswith(s) for s in speed_suffixes) for m in metrics), metrics + # We inject "epochs" and "batches" + assert "epochs" in metrics, metrics + assert "batches" in metrics, metrics + + +def test_train_metrics(tmp_path: pathlib.Path) -> None: + # Make sure that training metrics happen every 5 steps, as specified. + events = do_train( + tmp_path, + num_train_epochs=2, + evaluation_strategy="epoch", + logging_steps=5, + ) + data = utils.assert_events_match( + events, + "!report_metrics:training", + ("report_metrics:training:5", "metrics"), + "!report_metrics:training", + "report_metrics:training:10", + "!report_metrics:training", + "report_metrics:training:15", + # Trainer always logs training metrics before exiting. + "report_metrics:training:16", + "!report_metrics:training", + ) + # Check non-epoch metrics. + check_hf_metrics(data["metrics"]) + + # If logging_steps aligns with our exit batch (logging_steps == len(data)), we only log once. + events = do_train( + tmp_path, + num_train_epochs=1, + evaluation_strategy="epoch", + logging_steps=8, + ) + data = utils.assert_events_match( + events, + "!report_metrics:training", + ("report_metrics:training:8", "metrics"), + "!report_metrics:training", + ) + # Check epoch metrics. + check_hf_metrics(data["metrics"]) + + +def test_save_at_end(tmp_path: pathlib.Path) -> None: + # We force a save even if Transformers wouldn't. + events = do_train( + tmp_path, + num_train_epochs=1, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:8", + "report_checkpoint", + "after_save:8", + "!report_checkpoint", + ) + + # We can override it. Also, this tests that the previous case was valid, because it proves that + # the save that occured was the one we forced. + events = do_train( + tmp_path, + force_final_save=False, + num_train_epochs=1, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + ) + + # Also, if the trainer naturally saves at that time, we don't duplicate the save. 
+ events = do_train( + tmp_path, + # force_final_save=False, + num_train_epochs=1, + save_steps=8, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:8", + "report_checkpoint", + "after_save:8", + "!report_checkpoint", + ) + + # Same thing, but force_final_save=False to guarantee that the above test is valid (i.e. the + # save originated with Transformers). + events = do_train( + tmp_path, + force_final_save=False, + num_train_epochs=1, + save_steps=8, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:8", + "report_checkpoint", + "after_save:8", + "!report_checkpoint", + ) + + # Save a final checkpoint if we are preempted. + events = do_train( + tmp_path, + set_preempt_on_event="report_metrics:training:3", + logging_steps=1, + num_train_epochs=1, + ) + utils.assert_events_match( + events, + "!report_checkpoint", + "before_save:3", + "report_checkpoint", + "after_save:3", + "!report_checkpoint", + ) + + +def test_eval(tmp_path: pathlib.Path) -> None: + # Eval on epoch boundaries. + # (This test also ensures we don't double-evaluate with our evaluate-at-end logic). + events = do_train( + tmp_path, + num_train_epochs=2, + evaluation_strategy="epoch", + logging_steps=5, + ) + data = utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + ("report_metrics:validation:8", "metrics"), + "on_evaluate:8", + "!report_metrics:validation", + "!on_evaluate", + "report_metrics:validation:16", + "on_evaluate:16", + "!report_metrics:validation", + "!on_evaluate", + ) + # Check epoch metrics. + check_hf_metrics(data["metrics"]) + + # Eval off epoch boundaries, and once at the end. + events = do_train( + tmp_path, + num_train_epochs=1, + evaluation_strategy="steps", + eval_steps=5, + ) + data = utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + ("report_metrics:validation:5", "off-epoch-metrics"), + "on_evaluate:5", + "!report_metrics:validation", + "!on_evaluate", + ("report_metrics:validation:8", "final-metrics"), + "on_evaluate:8", + "!report_metrics:validation", + "!on_evaluate", + ) + # Check non-epoch metrics, and the at-end metrics. + check_hf_metrics(data["off-epoch-metrics"]) + check_hf_metrics(data["final-metrics"]) + + # Same thing, but we can disable the evaluate-at-end. Also this proves that our evaluate-at-end + # was working in the previous case. + events = do_train( + tmp_path, + force_final_evaluate=False, + num_train_epochs=1, + evaluation_strategy="steps", + eval_steps=5, + ) + utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + "report_metrics:validation:5", + "on_evaluate:5", + "!report_metrics:validation", + "!on_evaluate", + ) + + # Never evaluate-at-end if we got preempted. 
+ events = do_train( + tmp_path, + set_preempt_on_event="report_metrics:training:3", + num_train_epochs=1, + logging_steps=1, + evaluation_strategy="steps", + eval_steps=5, + ) + utils.assert_events_match( + events, + "!report_metrics:validation", + "!on_evaluate", + ) + + +def test_save_and_restore(tmp_path: pathlib.Path) -> None: + events = do_train( + tmp_path, + set_preempt_on_event="report_metrics:training:3", + max_steps=5, + logging_steps=1, + ) + data = utils.assert_events_match( + events, + ("after_epoch_end", "weight"), + ("report_checkpoint", "ckpt"), + ) + + # Make sure our next training continues from here. + ckpt = data["ckpt"] + ckpt_weight = data["weight"] + + # Note that model is loaded _after_ on_epoch_begin, so to know that we loaded a model we'll + # compare weight after training one batch to the checkpoint weight (which had more than one + # batch of training behind it). + events = do_train( + tmp_path, + latest_checkpoint=ckpt, + max_steps=1, + ) + data = utils.assert_events_match( + events, + # training should continue from global_step=3 + "on_train_begin:3", + ("after_epoch_end", "weight"), + ) + + # Model weight will be slowly moving from 0 to 1 throughout training. + assert data["weight"] > ckpt_weight
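For reference, a minimal sketch of how the reworked DetCallback is wired into a ``transformers.Trainer`` script under the new searcher semantics. The tiny model, dataset, ``compute_metrics`` helper, and ``/tmp/output`` path below are stand-ins modeled on the test fixtures above (not part of this PR), and the sketch assumes the script is submitted as a Determined Core API experiment whose searcher config names a reported metric (e.g. ``metric: eval_mae``) and, for ASHA searchers, a ``time_metric``/``max_time`` consistent with ``--max_steps``:

    # Hypothetical wiring sketch; only the DetCallback usage mirrors this PR.
    from typing import Dict

    import numpy as np
    import torch
    import transformers

    import determined as det
    import determined.transformers


    class OneVarModel(torch.nn.Linear):  # type: ignore
        """Tiny stand-in model, modeled on the MyOneVarModel test fixture above."""

        def __init__(self) -> None:
            super().__init__(1, 1, False)
            self._loss_fn = torch.nn.MSELoss()

        def forward(self, x: torch.Tensor, label_y: torch.Tensor) -> Dict[str, torch.Tensor]:
            y = super().forward(x)
            return {"loss": self._loss_fn(y, label_y), "pred_y": y}


    class OnesData(torch.utils.data.Dataset):
        """Stand-in dataset; keys must match the model's forward() signature."""

        def __len__(self) -> int:
            return 64

        def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
            return {"x": torch.Tensor([1.0]), "label_y": torch.Tensor([1.0])}


    def compute_metrics(pred: transformers.EvalPrediction) -> Dict[str, float]:
        # Report mean absolute error so the experiment config can use metric: eval_mae.
        return {"mae": float(np.abs(pred.predictions - pred.label_ids).mean())}


    def main(core_context: det.core.Context) -> None:
        args = transformers.TrainingArguments(
            output_dir="/tmp/output",  # placeholder path
            max_steps=100,  # with adaptive_asha, keep consistent with time_metric/max_time
            logging_steps=10,
            evaluation_strategy="steps",
            eval_steps=50,
            save_steps=50,
            disable_tqdm=True,
        )
        det_cb = det.transformers.DetCallback(core_context, args)
        trainer = transformers.Trainer(
            model=OneVarModel(),
            args=args,
            train_dataset=OnesData(),
            eval_dataset=OnesData(),
            compute_metrics=compute_metrics,
            callbacks=[det_cb],
        )
        # DetCallback sets args.resume_from_checkpoint when a Determined checkpoint is
        # available, but the Trainer only honors it if it is passed to train() explicitly,
        # as in test_save_and_restore above.
        trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)


    if __name__ == "__main__":
        with det.core.init() as core_context:
            main(core_context)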