From 2933cd68be67e97472341f4573c955cd730df528 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 4 Sep 2024 22:17:04 -0700 Subject: [PATCH 01/37] sft recipes to eval kd --- recipes/configs/llama3_1/1B_lora_kd_sft.yaml | 95 ++++++++++++++++++++ recipes/configs/llama3_1/8B_full_kd_sft.yaml | 79 ++++++++++++++++ recipes/configs/llama3_1/8B_lora_kd_sft.yaml | 82 +++++++++++++++++ 3 files changed, 256 insertions(+) create mode 100644 recipes/configs/llama3_1/1B_lora_kd_sft.yaml create mode 100644 recipes/configs/llama3_1/8B_full_kd_sft.yaml create mode 100644 recipes/configs/llama3_1/8B_lora_kd_sft.yaml diff --git a/recipes/configs/llama3_1/1B_lora_kd_sft.yaml b/recipes/configs/llama3_1/1B_lora_kd_sft.yaml new file mode 100644 index 0000000000..5256aaff58 --- /dev/null +++ b/recipes/configs/llama3_1/1B_lora_kd_sft.yaml @@ -0,0 +1,95 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1 + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + vocab_size: 128256 + num_layers: 16 + num_heads: 32 + num_kv_heads: 8 + embed_dim: 2048 + max_seq_len: 131072 + intermediate_dim: 8192 + attn_dropout: 0.0 + norm_eps: 1e-5 + rope_base: 500000.0 + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.05 + use_dora: False + quantize_base: False + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /home/lindawang/llama/Llama-1B/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /home/lindawang/llama/Llama-1B/lora_finetuned/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 8 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +# compile: False + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.TensorBoardLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/recipes/configs/llama3_1/8B_full_kd_sft.yaml b/recipes/configs/llama3_1/8B_full_kd_sft.yaml new file mode 100644 index 0000000000..ac1571552e --- /dev/null +++ b/recipes/configs/llama3_1/8B_full_kd_sft.yaml @@ -0,0 +1,79 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/full_finetuned/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 8 +epochs: 1 +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +optimizer_in_bwd: True +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.TensorBoardLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False diff --git a/recipes/configs/llama3_1/8B_lora_kd_sft.yaml b/recipes/configs/llama3_1/8B_lora_kd_sft.yaml new file mode 100644 index 0000000000..4ee757db1b --- /dev/null +++ b/recipes/configs/llama3_1/8B_lora_kd_sft.yaml @@ -0,0 +1,82 @@ +# Config for single device LoRA finetuning in lora_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# To launch on a single device, run the following command from root: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1_8b + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + lora_rank: 8 + lora_alpha: 16 + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 8 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 +# compile: False + +# Logging +output_dir: /tmp/lora_finetune_output +metric_logger: + _component_: torchtune.training.metric_logging.TensorBoardLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False From bff065a04d31b4b79e05c4884904ce13ef99e43a Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 5 Sep 2024 15:25:27 -0700 Subject: [PATCH 02/37] setup kd files --- recipes/configs/llama3_1/1B_lora_kd_sft.yaml | 95 --- .../llama3_1/8B_full_single_device.yaml | 106 ---- recipes/configs/llama3_1/8B_lora_kd_sft.yaml | 82 --- .../configs/llama3_1/kd_single_device.yaml | 0 recipes/kd_single_device.py | 560 ++++++++++++++++++ torchtune/_recipe_registry.py | 11 + 6 files changed, 571 insertions(+), 283 deletions(-) delete mode 100644 recipes/configs/llama3_1/1B_lora_kd_sft.yaml delete mode 100644 recipes/configs/llama3_1/8B_full_single_device.yaml delete mode 100644 recipes/configs/llama3_1/8B_lora_kd_sft.yaml create mode 100644 recipes/configs/llama3_1/kd_single_device.yaml create mode 100644 recipes/kd_single_device.py diff --git a/recipes/configs/llama3_1/1B_lora_kd_sft.yaml b/recipes/configs/llama3_1/1B_lora_kd_sft.yaml deleted file mode 100644 index 5256aaff58..0000000000 --- a/recipes/configs/llama3_1/1B_lora_kd_sft.yaml +++ /dev/null @@ -1,95 +0,0 @@ -# Config for single device LoRA finetuning in lora_finetune_single_device.py -# using a Llama3.1 8B Instruct model -# -# This config assumes that you've run the following command before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" -# -# To launch on a single device, run the following command from root: -# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device -# -# You can add specific overrides through the command line. For example -# to override the checkpointer directory while launching training -# you can run: -# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= -# -# This config works only for training on single device. - - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.lora_llama3_1 - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - vocab_size: 128256 - num_layers: 16 - num_heads: 32 - num_kv_heads: 8 - embed_dim: 2048 - max_seq_len: 131072 - intermediate_dim: 8192 - attn_dropout: 0.0 - norm_eps: 1e-5 - rope_base: 500000.0 - lora_rank: 8 - lora_alpha: 16 - lora_dropout: 0.05 - use_dora: False - quantize_base: False - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: null - -checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /home/lindawang/llama/Llama-1B/ - checkpoint_files: [ - consolidated.00.pth - ] - recipe_checkpoint: null - output_dir: /home/lindawang/llama/Llama-1B/lora_finetuned/ - model_type: LLAMA3 -resume_from_checkpoint: False -save_adapter_weights_only: False - -# Dataset and Sampler -dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset -seed: null -shuffle: True -batch_size: 8 - -# Optimizer and Scheduler -optimizer: - _component_: torch.optim.AdamW - weight_decay: 0.01 - lr: 3e-4 -lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup - num_warmup_steps: 100 - -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss - -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 2 -# compile: False - -# Logging -output_dir: /tmp/lora_finetune_output -metric_logger: - _component_: torchtune.training.metric_logging.TensorBoardLogger - log_dir: ${output_dir} -log_every_n_steps: 1 -log_peak_memory_stats: False - -# Environment -device: cuda -dtype: bf16 -enable_activation_checkpointing: False diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml deleted file mode 100644 index 754c5e9fa4..0000000000 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ /dev/null @@ -1,106 +0,0 @@ -# Config for single device full finetuning in full_finetune_single_device.py -# using a Llama3.1 8B Instruct model -# -# This config assumes that you've run the following command before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" -# -# The default config uses an optimizer from bitsandbytes. If you do not have it installed, -# you can install it with -# pip install bitsandbytes -# -# To launch on a single device, run the following command from root: -# tune run full_finetune_single_device --config llama3_1/8B_full_single_device -# -# You can add specific overrides through the command line. For example -# to override the checkpointer directory while launching training -# you can run: -# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= -# -# This config works only for training on single device. - - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: null - -# Dataset -dataset: - _component_: torchtune.datasets.alpaca_dataset -seed: null -shuffle: True - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - checkpoint_files: [ - model-00001-of-00004.safetensors, - model-00002-of-00004.safetensors, - model-00003-of-00004.safetensors, - model-00004-of-00004.safetensors - ] - recipe_checkpoint: null - output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - model_type: LLAMA3 -resume_from_checkpoint: False - -# Fine-tuning arguments -batch_size: 2 -epochs: 3 -optimizer: - _component_: bitsandbytes.optim.PagedAdamW8bit - lr: 2e-5 -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True -compile: False - -# Training environment -device: cuda - -# Memory management -enable_activation_checkpointing: True - -# Reduced precision -dtype: bf16 - -# Logging -metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger - log_dir: ${output_dir} -output_dir: /tmp/full-llama3.1-finetune -log_every_n_steps: 1 -log_peak_memory_stats: False - -# Profiler (disabled) -profiler: - _component_: torchtune.training.setup_torch_profiler - enabled: False - - #Output directory of trace artifacts - output_dir: ${output_dir}/profiling_outputs - - #`torch.profiler.ProfilerActivity` types to trace - cpu: True - cuda: True - - #trace options passed to `torch.profiler.profile` - profile_memory: True - with_stack: False - record_shapes: True - with_flops: False - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: 1 - warmup_steps: 2 - active_steps: 1 - num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_lora_kd_sft.yaml b/recipes/configs/llama3_1/8B_lora_kd_sft.yaml deleted file mode 100644 index 4ee757db1b..0000000000 --- a/recipes/configs/llama3_1/8B_lora_kd_sft.yaml +++ /dev/null @@ -1,82 +0,0 @@ -# Config for single device LoRA finetuning in lora_finetune_single_device.py -# using a Llama3.1 8B Instruct model -# -# This config assumes that you've run the following command before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" -# -# To launch on a single device, run the following command from root: -# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device -# -# You can add specific overrides through the command line. For example -# to override the checkpointer directory while launching training -# you can run: -# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device checkpointer.checkpoint_dir= -# -# This config works only for training on single device. - - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.lora_llama3_1_8b - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - lora_rank: 8 - lora_alpha: 16 - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: null - -checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/original/ - checkpoint_files: [ - consolidated.00.pth - ] - recipe_checkpoint: null - output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned/ - model_type: LLAMA3 -resume_from_checkpoint: False -save_adapter_weights_only: False - -# Dataset and Sampler -dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset -seed: null -shuffle: True -batch_size: 8 - -# Optimizer and Scheduler -optimizer: - _component_: torch.optim.AdamW - weight_decay: 0.01 - lr: 3e-4 -lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup - num_warmup_steps: 100 - -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss - -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 2 -# compile: False - -# Logging -output_dir: /tmp/lora_finetune_output -metric_logger: - _component_: torchtune.training.metric_logging.TensorBoardLogger - log_dir: ${output_dir} -log_every_n_steps: 1 -log_peak_memory_stats: False - -# Environment -device: cuda -dtype: bf16 -enable_activation_checkpointing: False diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml new file mode 100644 index 0000000000..e69de29bb2 diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py new file mode 100644 index 0000000000..aa6a21574b --- /dev/null +++ b/recipes/kd_single_device.py @@ -0,0 +1,560 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft import ( + get_adapter_params, + get_lora_module_names, + get_merged_lora_ckpt, + load_dora_magnitudes, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface + + +class KDRecipeSingleDevice(FTRecipeInterface): + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + # Reduced precision logic + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + # fp16 precision is explicitly disabled as it is not supported in this + # recipe (for example, no gradient scaling). + if self._dtype == torch.float16: + raise ValueError( + "fp16 precision is not supported in this recipe. Please use fp32 or bf16." + ) + # For CUDA devices, check if the HW supports bf16 if bf16 is specified. + if ( + self._dtype == torch.bfloat16 + and self._device != torch.device("cpu") + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError("Full bf16 training is not supported on this hardware.") + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._clip_grad_norm = cfg.get("clip_grad_norm", None) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + # TODO: Add after KD recipe is implemented + raise RuntimeError("Resume from checkpoint is not supported yet.") + return checkpoint_dict + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + # TODO: Add teacher model setup and KD loss + + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + self._model_compile = cfg.compile + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + + # set up model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + compile_model=cfg.compile, + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[training.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") + + # Dataloader depends on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # TODO: add after KD recipe is implemented + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False + # self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + compile_model: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + with training.set_default_dtype(self._dtype), self._device: + model = config.instantiate(cfg_model) + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + self.adapter_params = get_adapter_params(model) + self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) + set_trainable_params(model, self.adapter_params) + + if compile_model: + log.info("Compiling model layers with torch.compile...") + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + for m in reversed(list(model.modules())): + if isinstance(m, modules.transformer.TransformerSelfAttentionLayer): + m.compile(backend=backend) + + if enable_activation_checkpointing: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + base_missing, base_unexpected = model.load_state_dict( + base_model_state_dict, strict=False + ) + # This is for any adapters that need to be initialized after base weights + # have been loaded (e.g. DoRA). + if self._is_dora: + load_dora_magnitudes(model) + if lora_weights_state_dict: + lora_missing, lora_unexpected = model.load_state_dict( + lora_weights_state_dict, strict=False + ) + else: + lora_missing, lora_unexpected = None, None + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Validate model adapter params were loaded in with the expected dtype + # TODO (rohan-varma): Further validation to ensure the appropriate base params + # are NF4 vs bf16 based on the quantization config. + training.validate_expected_param_dtype( + self.adapter_params.items(), dtype=self._dtype + ) + + log.info(f"Model is initialized with precision {self._dtype}.") + + if self._device.type == "cuda": + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + optimizer.load_state_dict(opt_state_dict) + + log.info("Optimizer and loss are initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports + Map-style Datasets which fit into memory and an option for random shuffling. + Samplers, iterable datasets, and streaming datasets are not supported. + """ + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=1, + rank=0, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + sampler=sampler, + batch_size=batch_size, + collate_fn=( + partial( + padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None + ), + ) + + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + ckpt_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # if training is in-progress, checkpoint the optimizer state as well + if intermediate_checkpoint: + ckpt_dict.update( + { + training.OPT_KEY: self._optimizer.state_dict(), + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + # Move to CPU to avoid a copy on GPU + state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} + + # Construct the full state dict with LoRA weights merged into base LLM weights + merged_state_dict = get_merged_lora_ckpt( + state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) + + # Construct the adapter weights + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) + } + ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) + + self._checkpointer.save_checkpoint( + ckpt_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + # TODO: get logits from teacher and compute KD loss + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + # run model + logits = self._model(tokens, mask=mask, input_pos=input_pos) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute loss + loss = self._loss_fn(logits, labels) + + # free logits otherwise it peaks backward memory + del logits + + # TODO: return class and KD loss + return loss + + def train(self) -> None: + """ + The core training loop. + """ + + if self._model_compile: + log.info( + "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." + ) + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + with self._profiler as prof: + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + batch = {k: v.to(self._device) for k, v in batch.items()} + num_tokens += batch["tokens"].numel() + + # TODO: compute total loss and log losses + loss = self._loss_step(batch) + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + if self._clip_grad_norm is not None: + grad_norm = torch.nn.utils.clip_grad_norm_( + self._model.parameters(), + max_norm=float(self._clip_grad_norm), + ) + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if self.global_step % self._log_every_n_steps == 0: + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if ( + self._device.type == "cuda" + and self._log_peak_memory_stats + ): + log_dict.update( + utils.get_memory_stats(device=self._device) + ) + if self._clip_grad_norm is not None: + log_dict.update({"grad_norm": grad_norm}) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + prof.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + self._metric_logger.close() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + config.log_config(recipe_name="KDRecipeSingleDevice", cfg=cfg) + recipe = KDRecipeSingleDevice(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 86bf9829e8..9821d35ab6 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -287,6 +287,17 @@ class Recipe: ], supports_distributed=True, ), + Recipe( + name="kd_single_device", + file_path="kd_single_device.py", + configs=[ + Config( + name="llama3_1/kd_single_device", + file_path="llama3_1/kd_single_device.yaml", + ), + ], + supports_distributed=False, + ), ] From 9dd7b473192481069cf48fe8ec9c1cba54266caf Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 5 Sep 2024 15:28:04 -0700 Subject: [PATCH 03/37] delete test config --- recipes/configs/llama3_1/8B_full_kd_sft.yaml | 79 -------------------- 1 file changed, 79 deletions(-) delete mode 100644 recipes/configs/llama3_1/8B_full_kd_sft.yaml diff --git a/recipes/configs/llama3_1/8B_full_kd_sft.yaml b/recipes/configs/llama3_1/8B_full_kd_sft.yaml deleted file mode 100644 index ac1571552e..0000000000 --- a/recipes/configs/llama3_1/8B_full_kd_sft.yaml +++ /dev/null @@ -1,79 +0,0 @@ -# Config for single device full finetuning in full_finetune_single_device.py -# using a Llama3.1 8B Instruct model -# -# This config assumes that you've run the following command before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" -# -# The default config uses an optimizer from bitsandbytes. If you do not have it installed, -# you can install it with -# pip install bitsandbytes -# -# To launch on a single device, run the following command from root: -# tune run full_finetune_single_device --config llama3_1/8B_full_single_device -# -# You can add specific overrides through the command line. For example -# to override the checkpointer directory while launching training -# you can run: -# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= -# -# This config works only for training on single device. - - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: null - -# Dataset -dataset: - _component_: torchtune.datasets.alpaca_dataset -seed: null -shuffle: True - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/original/ - checkpoint_files: [ - consolidated.00.pth - ] - recipe_checkpoint: null - output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/full_finetuned/ - model_type: LLAMA3 -resume_from_checkpoint: False - -# Fine-tuning arguments -batch_size: 8 -epochs: 1 -optimizer: - _component_: torch.optim.AdamW - lr: 2e-5 - -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss -max_steps_per_epoch: null -gradient_accumulation_steps: 1 -optimizer_in_bwd: True -compile: False - -# Training environment -device: cuda - -# Memory management -enable_activation_checkpointing: True - -# Reduced precision -dtype: bf16 - -# Logging -output_dir: /tmp/lora_finetune_output -metric_logger: - _component_: torchtune.training.metric_logging.TensorBoardLogger - log_dir: ${output_dir} -log_every_n_steps: 1 -log_peak_memory_stats: False From a39e99cbfae852e5411847291855fe04b9126f23 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 5 Sep 2024 22:22:46 -0700 Subject: [PATCH 04/37] added student config --- .../configs/llama3_1/kd_single_device.yaml | 93 +++++++++++++++++++ recipes/kd_single_device.py | 77 ++++++++++++++- 2 files changed, 167 insertions(+), 3 deletions(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index e69de29bb2..2f762449e0 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -0,0 +1,93 @@ +# Config for single device knowledge distillation in kd_single_device.py +# using a teacher and student model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct +# +# To launch on a single device, run the following command from root: +# tune run kd_single_device --config llama3_1/kd_single_device +# +# This config works only for distilling on a single device. + + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.lora_llama3_1 + lora_attn_modules: ['q_proj', 'v_proj'] + apply_lora_to_mlp: False + apply_lora_to_output: False + vocab_size: 128256 + num_layers: 16 + num_heads: 32 + num_kv_heads: 8 + embed_dim: 2048 + max_seq_len: 131072 + intermediate_dim: 8192 + attn_dropout: 0.0 + norm_eps: 1e-5 + rope_base: 500000.0 + lora_rank: 8 + lora_alpha: 16 + lora_dropout: 0.05 + use_dora: False + quantize_base: False + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Llama-3.1-Student/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/kd_model/ + model_type: LLAMA3 + +# TODO: add teacher checkpointer + +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 8 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 16 +compile: False + +# Logging +output_dir: /tmp/kd_output +metric_logger: + _component_: torchtune.training.metric_logging.TensorBoardLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index aa6a21574b..ec857120ec 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -9,7 +9,7 @@ import time from functools import partial -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import torch from omegaconf import DictConfig, ListConfig @@ -29,6 +29,11 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") class KDRecipeSingleDevice(FTRecipeInterface): @@ -176,16 +181,82 @@ def setup(self, cfg: DictConfig) -> None: last_epoch=self.global_step - 1, ) - # TODO: add after KD recipe is implemented # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) # if cfg is missing profiler key or if `cfg.profiler.enabled = False - # self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) # Used to ignore labels for loss computation self.ignore_labels_cache = torch.full( (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device ) + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + def _setup_model( self, cfg_model: DictConfig, From 0c4e4f9514a57335863813a331bad628b3ff46e2 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Fri, 6 Sep 2024 16:06:43 -0700 Subject: [PATCH 05/37] added teacher model loading --- .../configs/llama3_1/kd_single_device.yaml | 12 ++++++ recipes/kd_single_device.py | 42 ++++++++++++++++--- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index 2f762449e0..1f5ee63c4a 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -33,6 +33,9 @@ model: use_dora: False quantize_base: False +teacher_model: + _component_: torchtune.models.llama3_1.llama3_1_8b + # Tokenizer tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer @@ -50,6 +53,15 @@ checkpointer: model_type: LLAMA3 # TODO: add teacher checkpointer +teacher_checkpointer: + _component_: torchtune.training.FullModelMetaCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/original/ + checkpoint_files: [ + consolidated.00.pth + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 resume_from_checkpoint: False save_adapter_weights_only: False diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index ec857120ec..1e22082a7a 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -18,7 +18,7 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler from torchtune import config, modules, training, utils -from torchtune.data import padded_collate +from torchtune.data import padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( get_adapter_params, @@ -61,7 +61,7 @@ def __init__(self, cfg: DictConfig) -> None: # These are public properties which are updated by the checkpoint loader # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = utils.set_seed(seed=cfg.seed) + self.seed = training.set_seed(seed=cfg.seed) self.epochs_run = 0 self.total_epochs = cfg.epochs self.max_steps_per_epoch = cfg.max_steps_per_epoch @@ -102,6 +102,9 @@ def setup(self, cfg: DictConfig) -> None: self._model_compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + teacher_checkpoint_dict = self.load_checkpoint( + cfg_checkpointer=cfg.teacher_checkpointer + ) # set up model self._model = self._setup_model( @@ -116,6 +119,11 @@ def setup(self, cfg: DictConfig) -> None: ), ) + self._teacher_model = self._setup_teacher_model( + model_cfg=cfg.teacher_model, + model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY], + ) + self._tokenizer = config.instantiate(cfg.tokenizer) log.info("Tokenizer is initialized from file.") @@ -321,8 +329,30 @@ def _setup_model( log.info(f"Model is initialized with precision {self._dtype}.") if self._device.type == "cuda": - memory_stats = utils.get_memory_stats(device=self._device) - utils.log_memory_stats(memory_stats) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + return model + + def _setup_teacher_model( + self, + model_cfg: DictConfig, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + with training.set_default_dtype(self._dtype), self._device: + model = config.instantiate(model_cfg) + + model.load_state_dict(model_state_dict) + + # Put model in eval mode. + # Note: This will not disable the dropout applied in SDPA, + # see https://github.com/pytorch/pytorch/issues/124464 + model.eval() + + # Validate model was loaded in with the expected dtype. + training.validate_expected_param_dtype( + model.named_parameters(), dtype=self._dtype + ) + log.info(f"Teacher model is initialized with precision {self._dtype}.") return model def _setup_optimizer( @@ -386,7 +416,7 @@ def _setup_data( batch_size=batch_size, collate_fn=( partial( - padded_collate, + padded_collate_sft, padding_idx=self._tokenizer.pad_id, ignore_idx=self._loss_fn.ignore_index, ) @@ -574,7 +604,7 @@ def train(self) -> None: and self._log_peak_memory_stats ): log_dict.update( - utils.get_memory_stats(device=self._device) + training.get_memory_stats(device=self._device) ) if self._clip_grad_norm is not None: log_dict.update({"grad_norm": grad_norm}) From 380f267f1c28ee890d901c967a13746e89c95078 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Fri, 6 Sep 2024 18:05:30 -0700 Subject: [PATCH 06/37] added loss --- .../configs/llama3_1/kd_single_device.yaml | 2 +- recipes/kd_single_device.py | 44 ++++++++++++++++--- 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index 1f5ee63c4a..a2433932e1 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -71,7 +71,7 @@ dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset seed: null shuffle: True -batch_size: 8 +batch_size: 4 # Optimizer and Scheduler optimizer: diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 1e22082a7a..7866a78d02 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -93,7 +93,6 @@ def setup(self, cfg: DictConfig) -> None: Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. """ - # TODO: Add teacher model setup and KD loss self._metric_logger = config.instantiate(cfg.metric_logger) @@ -142,6 +141,7 @@ def setup(self, cfg: DictConfig) -> None: if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) if self._model_compile: log.info("Compiling loss with torch.compile...") # For CEWithChunkedOutputLoss, if we compile the entire class @@ -491,8 +491,9 @@ def save_checkpoint(self, epoch: int) -> None: adapter_only=self._save_adapter_weights_only, ) - def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - # TODO: get logits from teacher and compute KD loss + def _loss_step( + self, batch: Dict[str, torch.Tensor] + ) -> (torch.Tensor, torch.Tensor): # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] @@ -515,14 +516,43 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: labels = labels.reshape(-1) logits = logits.reshape(-1, logits.size(-1)) + # Compute KD loss + teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) + # reshape logits to [bsz, s, v] + teacher_logits = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) + for logit_chunk in teacher_logits + ] + student_logits = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits + ] + distill_labels = [ + target_chunk.reshape(-1) + for target_chunk in labels.chunk(self._loss_fn.num_output_chunks, dim=1) + ] + total_distill_loss = 0.0 + for teacher_chunk, student_chunk, label_chunk in zip( + teacher_logits, student_logits, distill_labels + ): + teacher_prob = torch.nn.functional.softmax(teacher_chunk, dim=-1) + inf_mask = torch.isinf(student_chunk) + student_logprob = torch.nn.functional.log_softmax(student_chunk, dim=-1) + prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) + x = torch.sum(prod_probs, dim=-1).view(-1) + mask = (label_chunk != -100).int() + total_distill_loss += -torch.sum(x * mask.view(-1), dim=0) / torch.sum( + mask.view(-1), dim=0 + ) + # Compute loss loss = self._loss_fn(logits, labels) # free logits otherwise it peaks backward memory del logits + del teacher_logits + del student_logits - # TODO: return class and KD loss - return loss + return loss, total_distill_loss def train(self) -> None: """ @@ -566,8 +596,8 @@ def train(self) -> None: batch = {k: v.to(self._device) for k, v in batch.items()} num_tokens += batch["tokens"].numel() - # TODO: compute total loss and log losses - loss = self._loss_step(batch) + class_loss, kd_loss = self._loss_step(batch) + loss = 0.5 * class_loss + 0.5 * kd_loss loss = loss / self._gradient_accumulation_steps running_loss += loss loss.backward() From da2b4bbeea6a97a6371fab74fd44ddb3c10a53a0 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 10 Sep 2024 09:55:30 -0700 Subject: [PATCH 07/37] kd initial experiment config --- .../configs/llama3_1/kd_single_device.yaml | 4 +- recipes/kd_single_device.py | 43 ++++++++++++++++--- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index a2433932e1..4c6fc489be 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -55,9 +55,9 @@ checkpointer: # TODO: add teacher checkpointer teacher_checkpointer: _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/original/ + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned_single_device_epoch_1/ checkpoint_files: [ - consolidated.00.pth + meta_model_0.pt ] recipe_checkpoint: null output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 7866a78d02..cc8c2c841d 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -88,6 +88,23 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: raise RuntimeError("Resume from checkpoint is not supported yet.") return checkpoint_dict + def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + teacher_checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = teacher_checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + # TODO: Add after KD recipe is implemented + raise RuntimeError("Resume from checkpoint is not supported yet.") + return checkpoint_dict + def setup(self, cfg: DictConfig) -> None: """ Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), @@ -101,7 +118,7 @@ def setup(self, cfg: DictConfig) -> None: self._model_compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - teacher_checkpoint_dict = self.load_checkpoint( + teacher_checkpoint_dict = self.load_teacher_checkpoint( cfg_checkpointer=cfg.teacher_checkpointer ) @@ -517,12 +534,13 @@ def _loss_step( logits = logits.reshape(-1, logits.size(-1)) # Compute KD loss - teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) - # reshape logits to [bsz, s, v] - teacher_logits = [ - logit_chunk.reshape(-1, logit_chunk.size(-1)) - for logit_chunk in teacher_logits - ] + with torch.no_grad(): + teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) + # reshape logits to [bsz, s, v] + teacher_logits = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) + for logit_chunk in teacher_logits + ] student_logits = [ logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits ] @@ -551,6 +569,9 @@ def _loss_step( del logits del teacher_logits del student_logits + del teacher_prob + del student_logprob + del prod_probs return loss, total_distill_loss @@ -567,6 +588,8 @@ def train(self) -> None: # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() running_loss = 0 + running_class_loss = 0 + running_kd_loss = 0 num_tokens = 0 with self._profiler as prof: @@ -600,6 +623,8 @@ def train(self) -> None: loss = 0.5 * class_loss + 0.5 * kd_loss loss = loss / self._gradient_accumulation_steps running_loss += loss + running_class_loss += class_loss / self._gradient_accumulation_steps + running_kd_loss += kd_loss / self._gradient_accumulation_steps loss.backward() # Step with optimizer @@ -626,6 +651,8 @@ def train(self) -> None: time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, + "class_loss": running_class_loss.item(), + "kd_loss": running_kd_loss.item(), "lr": self._optimizer.param_groups[0]["lr"], "tokens_per_second_per_gpu": num_tokens / time_per_step, } @@ -645,6 +672,8 @@ def train(self) -> None: # Reset running stats for the next step running_loss = 0 + running_class_loss = 0 + running_kd_loss = 0 num_tokens = 0 t0 = time.perf_counter() From b54929a1c172c75c4f09c5296a9da51438fe7d65 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 10 Sep 2024 17:34:18 -0700 Subject: [PATCH 08/37] separated out loss func and added test --- .../configs/llama3_1/kd_single_device.yaml | 3 + recipes/kd_single_device.py | 65 +++++++++-------- .../torchtune/modules/loss/test_kd_losses.py | 56 ++++++++++++++ torchtune/modules/loss/__init__.py | 7 +- torchtune/modules/loss/kd_losses.py | 73 +++++++++++++++++++ 5 files changed, 174 insertions(+), 30 deletions(-) create mode 100644 tests/torchtune/modules/loss/test_kd_losses.py create mode 100644 torchtune/modules/loss/kd_losses.py diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index 4c6fc489be..d15954b53b 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -85,6 +85,9 @@ lr_scheduler: loss: _component_: torchtune.modules.loss.CEWithChunkedOutputLoss +kd_loss: + _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss + # Training epochs: 1 max_steps_per_epoch: null diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index cc8c2c841d..5b54de0964 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -154,9 +154,13 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) + self._kd_loss_fn = config.instantiate(cfg.kd_loss) backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model + assert ( + self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks + ), "Number of output chunks for loss_fn and kd_loss_fn must be the same." self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) if self._model_compile: @@ -537,30 +541,33 @@ def _loss_step( with torch.no_grad(): teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) # reshape logits to [bsz, s, v] - teacher_logits = [ - logit_chunk.reshape(-1, logit_chunk.size(-1)) - for logit_chunk in teacher_logits - ] - student_logits = [ - logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits - ] - distill_labels = [ - target_chunk.reshape(-1) - for target_chunk in labels.chunk(self._loss_fn.num_output_chunks, dim=1) - ] - total_distill_loss = 0.0 - for teacher_chunk, student_chunk, label_chunk in zip( - teacher_logits, student_logits, distill_labels - ): - teacher_prob = torch.nn.functional.softmax(teacher_chunk, dim=-1) - inf_mask = torch.isinf(student_chunk) - student_logprob = torch.nn.functional.log_softmax(student_chunk, dim=-1) - prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) - x = torch.sum(prod_probs, dim=-1).view(-1) - mask = (label_chunk != -100).int() - total_distill_loss += -torch.sum(x * mask.view(-1), dim=0) / torch.sum( - mask.view(-1), dim=0 - ) + # teacher_logits = [ + # logit_chunk.reshape(-1, logit_chunk.size(-1)) + # for logit_chunk in teacher_logits + # ] + # student_logits = [ + # logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits + # ] + # distill_labels = [ + # target_chunk.reshape(-1) + # for target_chunk in labels.chunk(self._loss_fn.num_output_chunks, dim=1) + # ] + # total_distill_loss = 0.0 + # for teacher_chunk, student_chunk, label_chunk in zip( + # teacher_logits, student_logits, distill_labels + # ): + # teacher_prob = torch.nn.functional.softmax(teacher_chunk, dim=-1) + # inf_mask = torch.isinf(student_chunk) + # student_logprob = torch.nn.functional.log_softmax(student_chunk, dim=-1) + # prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) + # x = torch.sum(prod_probs, dim=-1).view(-1) + # mask = (label_chunk != -100).int() + # total_distill_loss += -torch.sum(x * mask.view(-1), dim=0) / torch.sum( + # mask.view(-1), dim=0 + # ) + + # Compute kd loss + kd_loss = self._kd_loss_fn(logits, teacher_logits, labels) # Compute loss loss = self._loss_fn(logits, labels) @@ -568,12 +575,12 @@ def _loss_step( # free logits otherwise it peaks backward memory del logits del teacher_logits - del student_logits - del teacher_prob - del student_logprob - del prod_probs + # del student_logits + # del teacher_prob + # del student_logprob + # del prod_probs - return loss, total_distill_loss + return loss, kd_loss def train(self) -> None: """ diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py new file mode 100644 index 0000000000..28ee57e0a3 --- /dev/null +++ b/tests/torchtune/modules/loss/test_kd_losses.py @@ -0,0 +1,56 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected +from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss +from torchtune.training.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(42) + + +class TestForwardKLWithChunkedOutputLoss: + def test_forward_kl_loss(self): + # Create a sample input and label + ignore_index = -100 + batch_size = 3 + num_tokens = 50 + vocab_size = 50 + logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) + teacher_logits = torch.randn( + batch_size, num_tokens, vocab_size, dtype=torch.bfloat16 + ) + labels = torch.randint( + 0, vocab_size, (batch_size, num_tokens), dtype=torch.long + ) + + # add random ignore index to random tokens in the label + random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) + labels[random_indices < num_tokens // 5] = ignore_index + + # chunked FKL + chunked_fkl_loss = ForwardKLWithChunkedOutputLoss( + num_output_chunks=8, ignore_index=ignore_index + ) + logits_chunks = logits.chunk(chunked_fkl_loss.num_output_chunks, dim=1) + teacher_logits_chunks = teacher_logits.chunk( + chunked_fkl_loss.num_output_chunks, dim=1 + ) + chunked_loss = chunked_fkl_loss(logits_chunks, teacher_logits_chunks, labels) + + # vanilla FKL + fkl_loss = ForwardKLLoss(ignore_index=ignore_index) + logits = logits.reshape(-1, logits.size(-1)) + teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) + labels = labels.reshape(-1) + standard_loss = fkl_loss(logits, teacher_logits, labels) + + # Assert + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) diff --git a/torchtune/modules/loss/__init__.py b/torchtune/modules/loss/__init__.py index ed5ad0be04..385aa5389a 100644 --- a/torchtune/modules/loss/__init__.py +++ b/torchtune/modules/loss/__init__.py @@ -5,5 +5,10 @@ # LICENSE file in the root directory of this source tree. from .ce_chunked_output_loss import CEWithChunkedOutputLoss +from .kd_losses import ForwardKLLoss, ForwardKLWithChunkedOutputLoss -__all__ = ["CEWithChunkedOutputLoss"] +__all__ = [ + "CEWithChunkedOutputLoss", + "ForwardKLLoss", + "ForwardKLWithChunkedOutputLoss", +] diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py new file mode 100644 index 0000000000..fa1053fb45 --- /dev/null +++ b/torchtune/modules/loss/kd_losses.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import torch +import torch.nn.functional as F + + +class ForwardKLLoss(torch.nn.Module): + def __init__(self, ignore_index: int = -100): + super().__init__() + self.ignore_index = ignore_index + + def forward( + self, + student_logits: torch.Tensor, + teacher_logits: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + teacher_prob = F.softmax(teacher_logits, dim=-1) + inf_mask = torch.isinf(student_logits) + student_logprob = F.log_softmax(student_logits, dim=-1) + prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) + x = torch.sum(prod_probs, dim=-1).view(-1) + mask = (labels != self.ignore_index).int() + return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) + + +class ForwardKLWithChunkedOutputLoss(torch.nn.Module): + def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): + super().__init__() + self.num_output_chunks = num_output_chunks + self.ignore_index = ignore_index + self.fkl_loss = ForwardKLLoss(ignore_index) + + def forward( + self, + student_logits: List[torch.Tensor], + teacher_logits: List[torch.Tensor], + labels: torch.Tensor, + ) -> torch.Tensor: + # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)] + teacher_logits = [ + teacher_logits_chunk.reshape(-1, teacher_logits_chunk.size(-1)) + for teacher_logits_chunk in teacher_logits + ] + student_logits = [ + student_logits_chunk.reshape(-1, student_logits_chunk.size(-1)) + for student_logits_chunk in student_logits + ] + # chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)] + labels = [ + target_chunk.reshape(-1) + for target_chunk in labels.chunk(self.num_output_chunks, dim=1) + ] + total_fkl_loss = 0.0 + for student_chunk, teacher_chunk, label_chunk in zip( + student_logits, teacher_logits, labels + ): + total_fkl_loss += self.fkl_loss(student_chunk, teacher_chunk, label_chunk) + # teacher_prob = torch.nn.functional.softmax(teacher_chunk, dim=-1) + # inf_mask = torch.isinf(student_chunk) + # student_logprob = torch.nn.functional.log_softmax(student_chunk, dim=-1) + # prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) + # x = torch.sum(prod_probs, dim=-1).view(-1) + # mask = (label_chunk != -100).int() + # total_fkl_loss += -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) + + return total_fkl_loss / self.num_output_chunks From b31c56db308c3e5cdfe9fd51438d5d058fce6147 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 10 Sep 2024 22:28:44 -0700 Subject: [PATCH 09/37] added documentation --- .../configs/llama3_1/kd_single_device.yaml | 4 +- recipes/kd_single_device.py | 112 ++++++++++++------ torchtune/modules/loss/kd_losses.py | 62 ++++++++-- 3 files changed, 129 insertions(+), 49 deletions(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index d15954b53b..b6eea11f57 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -52,7 +52,7 @@ checkpointer: output_dir: /tmp/kd_model/ model_type: LLAMA3 -# TODO: add teacher checkpointer +# Teacher checkpoint teacher_checkpointer: _component_: torchtune.training.FullModelMetaCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned_single_device_epoch_1/ @@ -71,7 +71,7 @@ dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset seed: null shuffle: True -batch_size: 4 +batch_size: 8 # Optimizer and Scheduler optimizer: diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 5b54de0964..82684c7514 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -37,6 +37,72 @@ class KDRecipeSingleDevice(FTRecipeInterface): + """ + Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported.g + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with + 8-bit AdamW and Paged AdamW. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, + ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set + ``clip_grad_norm='inf'``. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + + """ + def __init__(self, cfg: DictConfig) -> None: self._device = utils.get_device(device=cfg.device) # Reduced precision logic @@ -84,25 +150,24 @@ def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: checkpoint_dict = self._checkpointer.load_checkpoint() if self._resume_from_checkpoint: - # TODO: Add after KD recipe is implemented - raise RuntimeError("Resume from checkpoint is not supported yet.") + if training.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) return checkpoint_dict def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ - Extract the checkpoint state from file and validate. This includes the - base model weights. If resume_from_checkpoint is True, this also includes - the adapter weights and recipe state + Extract the teacher checkpoint state from file. """ teacher_checkpointer = config.instantiate( cfg_checkpointer, resume_from_checkpoint=self._resume_from_checkpoint, ) checkpoint_dict = teacher_checkpointer.load_checkpoint() - - if self._resume_from_checkpoint: - # TODO: Add after KD recipe is implemented - raise RuntimeError("Resume from checkpoint is not supported yet.") return checkpoint_dict def setup(self, cfg: DictConfig) -> None: @@ -540,31 +605,6 @@ def _loss_step( # Compute KD loss with torch.no_grad(): teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) - # reshape logits to [bsz, s, v] - # teacher_logits = [ - # logit_chunk.reshape(-1, logit_chunk.size(-1)) - # for logit_chunk in teacher_logits - # ] - # student_logits = [ - # logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits - # ] - # distill_labels = [ - # target_chunk.reshape(-1) - # for target_chunk in labels.chunk(self._loss_fn.num_output_chunks, dim=1) - # ] - # total_distill_loss = 0.0 - # for teacher_chunk, student_chunk, label_chunk in zip( - # teacher_logits, student_logits, distill_labels - # ): - # teacher_prob = torch.nn.functional.softmax(teacher_chunk, dim=-1) - # inf_mask = torch.isinf(student_chunk) - # student_logprob = torch.nn.functional.log_softmax(student_chunk, dim=-1) - # prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) - # x = torch.sum(prod_probs, dim=-1).view(-1) - # mask = (label_chunk != -100).int() - # total_distill_loss += -torch.sum(x * mask.view(-1), dim=0) / torch.sum( - # mask.view(-1), dim=0 - # ) # Compute kd loss kd_loss = self._kd_loss_fn(logits, teacher_logits, labels) @@ -575,10 +615,6 @@ def _loss_step( # free logits otherwise it peaks backward memory del logits del teacher_logits - # del student_logits - # del teacher_prob - # del student_logprob - # del prod_probs return loss, kd_loss diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py index fa1053fb45..bc62200b51 100644 --- a/torchtune/modules/loss/kd_losses.py +++ b/torchtune/modules/loss/kd_losses.py @@ -11,6 +11,11 @@ class ForwardKLLoss(torch.nn.Module): + """ + The Kullback-Leibler divergence loss for valid indexes. + Implementation of https://github.com/jongwooko/distillm/blob/master/distillm/losses.py. + """ + def __init__(self, ignore_index: int = -100): super().__init__() self.ignore_index = ignore_index @@ -21,9 +26,22 @@ def forward( teacher_logits: torch.Tensor, labels: torch.Tensor, ) -> torch.Tensor: - teacher_prob = F.softmax(teacher_logits, dim=-1) + """ + Args: + student_logits (torch.Tensor): logits from student model of shape + (batch_size*num_tokens, vocab_size). + teacher_logits (torch.Tensor): logits from teacher model of shape + (batch_size*num_tokens, vocab_size). + labels (torch.Tensor): Ground truth labels of shape + (batch_size, vocab_size). + + Returns: + torch.Tensor: KL divergence loss of shape (1,). + """ + + teacher_prob = F.softmax(teacher_logits, dim=-1, dtype=torch.float32) inf_mask = torch.isinf(student_logits) - student_logprob = F.log_softmax(student_logits, dim=-1) + student_logprob = F.log_softmax(student_logits, dim=-1, dtype=torch.float32) prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) x = torch.sum(prod_probs, dim=-1).view(-1) mask = (labels != self.ignore_index).int() @@ -31,6 +49,16 @@ def forward( class ForwardKLWithChunkedOutputLoss(torch.nn.Module): + """ + Forward KL with chunked outputs that saves memory by only upcasting one chunk at a time. + + Since the model is trained with bf16, before computing KL divergence, we have to upcast + it to fp32 for better accuracy and stability. When upcasting happens, the memory usage doubles. + Models like llama3 have large vocabulary size and, therefore, have a large output + result (bsz, num_tokens, vocab_size). If we chunk on the token level, you can still compute + the cross entropy normally, but upcasting only one chunk at a time saves considerable memory. + """ + def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): super().__init__() self.num_output_chunks = num_output_chunks @@ -43,6 +71,29 @@ def forward( teacher_logits: List[torch.Tensor], labels: torch.Tensor, ) -> torch.Tensor: + """ + Args: + student_logits (List[torch.Tensor]): List of chunked logits from student model of length + ``self.num_output_chunks``, where each chunk has shape + (batch_size, num_tokens / num_output_chunks, vocab_size). + teacher_logits (List[torch.Tensor]): List of chunked logits from teacher model of length + ``self.num_output_chunks``, where each chunk has shape + (batch_size, num_tokens / num_output_chunks, vocab_size). + labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens). + + Returns: + torch.Tensor: KL divergence loss of shape (1,). + + Example: + >>> loss_fn = ForwardKLWithChunkedOutputLoss() + >>> + >>> h = torch.tensor([bsz, num_tokens, dim]) + >>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] + >>> teacher_chunks = [teacher_model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] + >>> labels = torch.tensor([bsz, num_tokens]) + >>> loss = loss_fn(output_chunks, teacher_chunks, labels) + """ + # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)] teacher_logits = [ teacher_logits_chunk.reshape(-1, teacher_logits_chunk.size(-1)) @@ -62,12 +113,5 @@ def forward( student_logits, teacher_logits, labels ): total_fkl_loss += self.fkl_loss(student_chunk, teacher_chunk, label_chunk) - # teacher_prob = torch.nn.functional.softmax(teacher_chunk, dim=-1) - # inf_mask = torch.isinf(student_chunk) - # student_logprob = torch.nn.functional.log_softmax(student_chunk, dim=-1) - # prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) - # x = torch.sum(prod_probs, dim=-1).view(-1) - # mask = (label_chunk != -100).int() - # total_fkl_loss += -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) return total_fkl_loss / self.num_output_chunks From fe5ed97b8374e759b8e14650f3b58baf8c3b2f31 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 11 Sep 2024 08:26:33 -0700 Subject: [PATCH 10/37] added prereq command to config --- recipes/configs/llama3_1/kd_single_device.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index b6eea11f57..3d1c2f5bd7 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -1,9 +1,10 @@ # Config for single device knowledge distillation in kd_single_device.py # using a teacher and student model # -# This config assumes that you've run the following command before launching +# This config assumes that you've run the following commands before launching # this run: # tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct +# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device # # To launch on a single device, run the following command from root: # tune run kd_single_device --config llama3_1/kd_single_device From 3f7fe704fb56761ad2d322bf96a5f2c9290f557b Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 11 Sep 2024 08:45:57 -0700 Subject: [PATCH 11/37] re-add 8B config --- .../llama3_1/8B_full_single_device.yaml | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 recipes/configs/llama3_1/8B_full_single_device.yaml diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml new file mode 100644 index 0000000000..5da403804e --- /dev/null +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -0,0 +1,80 @@ +# Config for single device full finetuning in full_finetune_single_device.py +# using a Llama3.1 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# The default config uses an optimizer from bitsandbytes. If you do not have it installed, +# you can install it with +# pip install bitsandbytes +# +# To launch on a single device, run the following command from root: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run full_finetune_single_device --config llama3_1/8B_full_single_device checkpointer.checkpoint_dir= +# +# This config works only for training on single device. + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.alpaca_dataset +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 2 +epochs: 3 +optimizer: + _component_: bitsandbytes.optim.PagedAdamW8bit + lr: 2e-5 +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +optimizer_in_bwd: True +compile: False + +# Training environment +device: cuda + +# Memory management +enable_activation_checkpointing: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.DiskLogger + log_dir: ${output_dir} +output_dir: /tmp/full-llama3.1-finetune +log_every_n_steps: 1 +log_peak_memory_stats: False From a87aa0c4cbf824fd78ee5b0725006c9e68b7c5cb Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 11 Sep 2024 12:51:09 -0700 Subject: [PATCH 12/37] added kd ratio --- recipes/configs/llama3_1/kd_single_device.yaml | 1 + recipes/kd_single_device.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml index 3d1c2f5bd7..19cc67cec0 100644 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ b/recipes/configs/llama3_1/kd_single_device.yaml @@ -88,6 +88,7 @@ loss: kd_loss: _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss +kd_ratio: 0.5 # Training epochs: 1 diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 82684c7514..408ccc3fed 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -136,6 +136,7 @@ def __init__(self, cfg: DictConfig) -> None: self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) self._gradient_accumulation_steps = cfg.gradient_accumulation_steps self._clip_grad_norm = cfg.get("clip_grad_norm", None) + self._kd_ratio = cfg.get("kd_ratio", 0.5) def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: """ @@ -663,7 +664,7 @@ def train(self) -> None: num_tokens += batch["tokens"].numel() class_loss, kd_loss = self._loss_step(batch) - loss = 0.5 * class_loss + 0.5 * kd_loss + loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss loss = loss / self._gradient_accumulation_steps running_loss += loss running_class_loss += class_loss / self._gradient_accumulation_steps From f5feac4be999199ff3ff34ded6c970fd1218a050 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 11 Sep 2024 12:54:47 -0700 Subject: [PATCH 13/37] revert 8b config --- .../llama3_1/8B_full_single_device.yaml | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 5da403804e..754c5e9fa4 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -24,6 +24,7 @@ tokenizer: _component_: torchtune.models.llama3.llama3_tokenizer path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model + max_seq_len: null # Dataset dataset: @@ -36,7 +37,7 @@ model: _component_: torchtune.models.llama3_1.llama3_1_8b checkpointer: - _component_: torchtune.utils.FullModelHFCheckpointer + _component_: torchtune.training.FullModelHFCheckpointer checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ checkpoint_files: [ model-00001-of-00004.safetensors, @@ -56,7 +57,7 @@ optimizer: _component_: bitsandbytes.optim.PagedAdamW8bit lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 optimizer_in_bwd: True @@ -73,8 +74,33 @@ dtype: bf16 # Logging metric_logger: - _component_: torchtune.utils.metric_logging.DiskLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} output_dir: /tmp/full-llama3.1-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.training.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: True + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 1 + warmup_steps: 2 + active_steps: 1 + num_cycles: 1 From 8c3c42ac5af8753b7255db05f53986dcad7297bc Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 12 Sep 2024 12:12:17 -0700 Subject: [PATCH 14/37] add kd recipe test --- recipes/kd_single_device.py | 42 ++++- tests/recipes/test_kd_single_device.py | 223 +++++++++++++++++++++++++ torchtune/modules/loss/kd_losses.py | 2 + 3 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 tests/recipes/test_kd_single_device.py diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 408ccc3fed..8adfa9b78f 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -10,6 +10,7 @@ from functools import partial from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn import torch from omegaconf import DictConfig, ListConfig @@ -166,11 +167,50 @@ def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any """ teacher_checkpointer = config.instantiate( cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, ) checkpoint_dict = teacher_checkpointer.load_checkpoint() return checkpoint_dict + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + def setup(self, cfg: DictConfig) -> None: """ Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), diff --git a/tests/recipes/test_kd_single_device.py b/tests/recipes/test_kd_single_device.py new file mode 100644 index 0000000000..49d55e3685 --- /dev/null +++ b/tests/recipes/test_kd_single_device.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + TOKENIZER_PATHS, +) + + +class TestLoRAFinetuneSingleDeviceRecipe: + def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): + return [ + "batch_size=8", + "device=cpu", + f"dtype={dtype_str}", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + f"epochs={epochs}", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "gradient_accumulation_steps=1", + "clip_grad_norm=100", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama3": [11.0651, 11.0577, 11.0540, 11.7671], + } + return loss_values_map[model_type] + + # @pytest.mark.integration_test + @pytest.mark.parametrize("compile", [True, False]) + @pytest.mark.parametrize( + "config, model_type, ckpt_type", + [ + ("llama3_1/kd_single_device", "llama3", "tune"), + ], + ) + def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + cmd = f""" + tune run kd_single_device \ + --config {config} \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + teacher_checkpointer._component_={ckpt_component} \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type={model_type.upper()} \ + ~model.intermediate_dim \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + compile={compile} \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS[model_type] + ] + + cmd = ( + cmd + + self._get_test_config_overrides(dtype_str="fp32") + + model_config + + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Make sure to clear compile state in between tests + if compile: + torch._dynamo.reset() + + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses] + expected_loss_values = self._fetch_expected_loss_values(model_type) + print(loss_values) + print(expected_loss_values) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + # @pytest.mark.integration_test + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run kd_single_device \ + --config llama3_1/kd_single_device \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type=LLAMA3 \ + ~model.intermediate_dim \ + tokenizer.path={tokenizer_path} \ + tokenizer.prompt_template=null \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] + ] + + cmd_1 = ( + cmd_1 + self._get_test_config_overrides() + model_config + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd_1) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run kd_single_device \ + --config llama3_1/kd_single_device \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type=LLAMA3 \ + ~model.intermediate_dim \ + resume_from_checkpoint=True \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + tokenizer.path={tokenizer_path} \ + tokenizer.prompt_template=null \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides(epochs=3) + + model_config + + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd_2) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Second epoch only + expected_loss_values = self._fetch_expected_loss_values("llama3")[2:] + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses][:2] + + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py index bc62200b51..ada57fb6bc 100644 --- a/torchtune/modules/loss/kd_losses.py +++ b/torchtune/modules/loss/kd_losses.py @@ -45,6 +45,8 @@ def forward( prod_probs = torch.masked_fill(teacher_prob * student_logprob, inf_mask, 0) x = torch.sum(prod_probs, dim=-1).view(-1) mask = (labels != self.ignore_index).int() + if torch.sum(mask.view(-1), dim=0) == 0: + return torch.tensor(0.0, device=x.device) return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) From 6ba0514a789137346bce4490ef66a6f2615614d1 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 12 Sep 2024 12:22:08 -0700 Subject: [PATCH 15/37] mark as integration test --- tests/recipes/test_kd_single_device.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/recipes/test_kd_single_device.py b/tests/recipes/test_kd_single_device.py index 49d55e3685..18a59498eb 100644 --- a/tests/recipes/test_kd_single_device.py +++ b/tests/recipes/test_kd_single_device.py @@ -49,7 +49,7 @@ def _fetch_expected_loss_values(self, model_type): } return loss_values_map[model_type] - # @pytest.mark.integration_test + @pytest.mark.integration_test @pytest.mark.parametrize("compile", [True, False]) @pytest.mark.parametrize( "config, model_type, ckpt_type", @@ -119,7 +119,7 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch) loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 ) - # @pytest.mark.integration_test + @pytest.mark.integration_test def test_training_state_on_resume(self, tmpdir, monkeypatch): """Test whether the recipe state is correctly updated on resume. Since this is model agnostic, we should run this on the small model only. The test From 04ea6496abf2690b93e1a35e97b7110fe2cd2025 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 12 Sep 2024 14:35:07 -0700 Subject: [PATCH 16/37] add save and load weights test --- tests/recipes/test_kd_single_device.py | 83 +++++++++++++++++++++++++- 1 file changed, 82 insertions(+), 1 deletion(-) diff --git a/tests/recipes/test_kd_single_device.py b/tests/recipes/test_kd_single_device.py index 18a59498eb..29e3651773 100644 --- a/tests/recipes/test_kd_single_device.py +++ b/tests/recipes/test_kd_single_device.py @@ -11,6 +11,7 @@ import pytest import torch +from omegaconf import OmegaConf from tests.common import TUNE_PATH from tests.recipes.utils import ( CKPT_COMPONENT_MAP, @@ -24,9 +25,10 @@ get_loss_values_from_metric_logger, TOKENIZER_PATHS, ) +from torchtune import config -class TestLoRAFinetuneSingleDeviceRecipe: +class TestKDSingleDeviceRecipe: def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): return [ "batch_size=8", @@ -221,3 +223,82 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 ) + + @pytest.mark.integration_test + def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): + ckpt_type = "tune" + model_type = "llama3" + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + cmd = f""" + tune run kd_single_device \ + --config llama3_1/kd_single_device \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + teacher_checkpointer._component_={ckpt_component} \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type={model_type.upper()} \ + ~model.intermediate_dim \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS[model_type] + ] + + cmd = ( + cmd + + self._get_test_config_overrides(dtype_str="fp32") + + model_config + + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd) + with pytest.raises(SystemExit, match=""): + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a Llama3 base model + # and the base model weights + trained adapter weights in the LoRA Llama 3 model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base llama3 model for loading merged weights + base_llama3_config = MODEL_TEST_CONFIGS[model_type] + llama3_model = config.instantiate( + OmegaConf.from_dotlist(base_llama3_config).model + ) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into 3 and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + llama3_model.load_state_dict(sd) + merged_ckpt_out = llama3_model(inputs) + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) From 62faa1d073c926a2c345b5772736b1e91eda1105 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 12 Sep 2024 19:13:57 -0700 Subject: [PATCH 17/37] fix comments 1 --- torchtune/modules/loss/kd_losses.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py index ada57fb6bc..bdeff3aa5c 100644 --- a/torchtune/modules/loss/kd_losses.py +++ b/torchtune/modules/loss/kd_losses.py @@ -13,7 +13,12 @@ class ForwardKLLoss(torch.nn.Module): """ The Kullback-Leibler divergence loss for valid indexes. - Implementation of https://github.com/jongwooko/distillm/blob/master/distillm/losses.py. + Implementation of https://github.com/jongwooko/distillm/blob/17c0f98bc263b1861a02d5df578c84aea652ee65/distillm/losses.py + + Args: + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + The loss is divided over non-ignored targets. + Default: -100. """ def __init__(self, ignore_index: int = -100): @@ -59,6 +64,14 @@ class ForwardKLWithChunkedOutputLoss(torch.nn.Module): Models like llama3 have large vocabulary size and, therefore, have a large output result (bsz, num_tokens, vocab_size). If we chunk on the token level, you can still compute the cross entropy normally, but upcasting only one chunk at a time saves considerable memory. + + Args: + num_output_chunks (int): Number of chunks to chunk the output into. Each chunk has shape + (batch_size, num_tokens / num_output_chunks, vocab_size). + Default: 8 + ignore_index (int): Specifies a target value that is ignored and does not contribute to the input gradient. + The loss is divided over non-ignored targets. + Default: -100 """ def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): From bf15406f623e9d26047097804e6b3492f87c368e Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 12 Sep 2024 20:44:44 -0700 Subject: [PATCH 18/37] address kd loss test comments --- .../torchtune/modules/loss/test_kd_losses.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py index 28ee57e0a3..6903f696e8 100644 --- a/tests/torchtune/modules/loss/test_kd_losses.py +++ b/tests/torchtune/modules/loss/test_kd_losses.py @@ -54,3 +54,63 @@ def test_forward_kl_loss(self): # Assert assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) + + def test_forward_kl_loss_expected(self): + student_logits = torch.tensor( + [ + [ + [1.1250, -0.4102, -0.0879, -2.5000], + [0.2676, 0.3535, 0.8711, -1.4688], + [-0.1084, 1.6641, 0.0084, 0.1196], + [0.5000, -0.6406, -0.2236, -1.5938], + ], + [ + [-1.5312, -1.9219, 0.0000, -0.5039], + [-1.5391, 1.5312, 0.5820, 0.2695], + [-0.3887, 1.2188, 0.0000, 0.6055], + [0.5000, 1.3828, 0.1309, -1.0312], + ], + ], + dtype=torch.bfloat16, + ) + teacher_logits = torch.tensor( + [ + [ + [-0.0381, -1.2578, -1.2031, 0.0947], + [-0.7852, 0.4492, 1.5547, 0.0972], + [0.8203, 0.0012, 0.7656, 0.3477], + [-1.5781, 0.4297, 0.5977, 0.3926], + ], + [ + [1.5156, 0.1641, 2.0781, -0.7734], + [-0.5898, 0.4453, -0.7969, 0.6328], + [0.6289, -0.8359, 0.9258, 0.2109], + [0.0006, 0.5195, 3.2344, -1.5781], + ], + ], + dtype=torch.bfloat16, + ) + labels = torch.tensor([[0, 3, 3, 1], [1, 1, 1, 1]]) + expected_loss = torch.tensor(1.7209, dtype=torch.float32) + + # chunked FKL loss + chunked_fkl_loss = ForwardKLWithChunkedOutputLoss( + num_output_chunks=2, ignore_index=-100 + ) + student_logits_chunks = student_logits.chunk( + chunked_fkl_loss.num_output_chunks, dim=1 + ) + teacher_logits_chunks = teacher_logits.chunk( + chunked_fkl_loss.num_output_chunks, dim=1 + ) + chunked_loss = chunked_fkl_loss( + student_logits_chunks, teacher_logits_chunks, labels + ) + + # vanilla FKL loss + fkl_loss = ForwardKLLoss(ignore_index=-100) + standard_loss = fkl_loss(student_logits, teacher_logits, labels) + + # assert + assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) + assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) From ac9eb0e542bf5a1fde68f80b71f0aed15bd3418c Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 12 Sep 2024 23:41:41 -0700 Subject: [PATCH 19/37] change to qwen2 --- .../configs/llama3_1/kd_single_device.yaml | 110 ------------------ recipes/configs/qwen2/kd_single_device.yaml | 97 +++++++++++++++ tests/recipes/test_kd_single_device.py | 20 ++-- torchtune/_recipe_registry.py | 4 + 4 files changed, 113 insertions(+), 118 deletions(-) delete mode 100644 recipes/configs/llama3_1/kd_single_device.yaml create mode 100644 recipes/configs/qwen2/kd_single_device.yaml diff --git a/recipes/configs/llama3_1/kd_single_device.yaml b/recipes/configs/llama3_1/kd_single_device.yaml deleted file mode 100644 index 19cc67cec0..0000000000 --- a/recipes/configs/llama3_1/kd_single_device.yaml +++ /dev/null @@ -1,110 +0,0 @@ -# Config for single device knowledge distillation in kd_single_device.py -# using a teacher and student model -# -# This config assumes that you've run the following commands before launching -# this run: -# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct -# tune run lora_finetune_single_device --config llama3_1/8B_lora_single_device -# -# To launch on a single device, run the following command from root: -# tune run kd_single_device --config llama3_1/kd_single_device -# -# This config works only for distilling on a single device. - - -# Model Arguments -model: - _component_: torchtune.models.llama3_1.lora_llama3_1 - lora_attn_modules: ['q_proj', 'v_proj'] - apply_lora_to_mlp: False - apply_lora_to_output: False - vocab_size: 128256 - num_layers: 16 - num_heads: 32 - num_kv_heads: 8 - embed_dim: 2048 - max_seq_len: 131072 - intermediate_dim: 8192 - attn_dropout: 0.0 - norm_eps: 1e-5 - rope_base: 500000.0 - lora_rank: 8 - lora_alpha: 16 - lora_dropout: 0.05 - use_dora: False - quantize_base: False - -teacher_model: - _component_: torchtune.models.llama3_1.llama3_1_8b - -# Tokenizer -tokenizer: - _component_: torchtune.models.llama3.llama3_tokenizer - path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model - max_seq_len: null - -checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Llama-3.1-Student/ - checkpoint_files: [ - consolidated.00.pth - ] - recipe_checkpoint: null - output_dir: /tmp/kd_model/ - model_type: LLAMA3 - -# Teacher checkpoint -teacher_checkpointer: - _component_: torchtune.training.FullModelMetaCheckpointer - checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/lora_finetuned_single_device_epoch_1/ - checkpoint_files: [ - meta_model_0.pt - ] - recipe_checkpoint: null - output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ - model_type: LLAMA3 - -resume_from_checkpoint: False -save_adapter_weights_only: False - -# Dataset and Sampler -dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset -seed: null -shuffle: True -batch_size: 8 - -# Optimizer and Scheduler -optimizer: - _component_: torch.optim.AdamW - weight_decay: 0.01 - lr: 3e-4 -lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup - num_warmup_steps: 100 - -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss - -kd_loss: - _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss -kd_ratio: 0.5 - -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 16 -compile: False - -# Logging -output_dir: /tmp/kd_output -metric_logger: - _component_: torchtune.training.metric_logging.TensorBoardLogger - log_dir: ${output_dir} -log_every_n_steps: 1 -log_peak_memory_stats: False - -# Environment -device: cuda -dtype: bf16 -enable_activation_checkpointing: False diff --git a/recipes/configs/qwen2/kd_single_device.yaml b/recipes/configs/qwen2/kd_single_device.yaml new file mode 100644 index 0000000000..99bf7872bf --- /dev/null +++ b/recipes/configs/qwen2/kd_single_device.yaml @@ -0,0 +1,97 @@ +# Config for single device knowledge distillation in kd_single_device.py +# using a teacher and student model +# +# This config assumes that you've ran the following commands before launching KD: +# First download the student and teacher models +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: +# tune run lora_finetune_single_device --config qwen2/1.5B_lora_single_device +# +# To launch on a single device, run the following command from root: +# tune run kd_single_device --config qwen2/kd_single_device +# +# This config works only for distilling on a single device. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_0_5b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +teacher_model: + _component_: torchtune.models.qwen2.qwen2_1_5b + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-kd + model_type: QWEN2 + +teacher_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + checkpoint_files: [ + hf_model_0001_0.pt + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + model_type: QWEN2 + +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 16 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +kd_loss: + _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss +kd_ratio: 0.5 + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 1 +compile: False + +# Logging +output_dir: /tmp/qwen_kd +metric_logger: + _component_: torchtune.training.metric_logging.TensorBoardLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False diff --git a/tests/recipes/test_kd_single_device.py b/tests/recipes/test_kd_single_device.py index 29e3651773..b1ad8826b7 100644 --- a/tests/recipes/test_kd_single_device.py +++ b/tests/recipes/test_kd_single_device.py @@ -56,7 +56,7 @@ def _fetch_expected_loss_values(self, model_type): @pytest.mark.parametrize( "config, model_type, ckpt_type", [ - ("llama3_1/kd_single_device", "llama3", "tune"), + ("qwen2/kd_single_device", "llama3", "tune"), ], ) def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): @@ -81,9 +81,10 @@ def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch) teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ teacher_checkpointer.model_type={model_type.upper()} \ - ~model.intermediate_dim \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ compile={compile} \ @@ -145,7 +146,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Train for two epochs cmd_1 = f""" tune run kd_single_device \ - --config llama3_1/kd_single_device \ + --config qwen2/kd_single_device \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -157,9 +158,10 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ teacher_checkpointer.model_type=LLAMA3 \ - ~model.intermediate_dim \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path={tokenizer_path} \ tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ kd_ratio=0.5 \ @@ -180,7 +182,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Resume training cmd_2 = f""" tune run kd_single_device \ - --config llama3_1/kd_single_device \ + --config qwen2/kd_single_device \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -194,12 +196,13 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ teacher_checkpointer.model_type=LLAMA3 \ - ~model.intermediate_dim \ resume_from_checkpoint=True \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path={tokenizer_path} \ tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ kd_ratio=0.5 \ """.split() @@ -237,7 +240,7 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): cmd = f""" tune run kd_single_device \ - --config llama3_1/kd_single_device \ + --config qwen2/kd_single_device \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -249,9 +252,10 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ teacher_checkpointer.model_type={model_type.upper()} \ - ~model.intermediate_dim \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 0e0b95b010..825e5c7719 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -285,6 +285,10 @@ class Recipe: name="llama3_1/kd_single_device", file_path="llama3_1/kd_single_device.yaml", ), + Config( + name="qwen2/kd_single_device", + file_path="qwen2/kd_single_device.yaml", + ), ], supports_distributed=False, ), From 87a80b6a9f7b992149fb0e463f08bfa440d78132 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Sat, 14 Sep 2024 23:54:37 -0700 Subject: [PATCH 20/37] addressing recipe comments --- recipes/configs/qwen2/kd_single_device.yaml | 6 +-- recipes/kd_single_device.py | 50 +++++++++------------ torchtune/_recipe_registry.py | 4 -- 3 files changed, 24 insertions(+), 36 deletions(-) diff --git a/recipes/configs/qwen2/kd_single_device.yaml b/recipes/configs/qwen2/kd_single_device.yaml index 99bf7872bf..a596ba7aad 100644 --- a/recipes/configs/qwen2/kd_single_device.yaml +++ b/recipes/configs/qwen2/kd_single_device.yaml @@ -59,7 +59,7 @@ dataset: _component_: torchtune.datasets.alpaca_cleaned_dataset seed: null shuffle: True -batch_size: 16 +batch_size: 8 # Optimizer and Scheduler optimizer: @@ -80,13 +80,13 @@ kd_ratio: 0.5 # Training epochs: 1 max_steps_per_epoch: null -gradient_accumulation_steps: 1 +gradient_accumulation_steps: 2 compile: False # Logging output_dir: /tmp/qwen_kd metric_logger: - _component_: torchtune.training.metric_logging.TensorBoardLogger + _component_: torchtune.training.metric_logging.DiskLogger log_dir: ${output_dir} log_every_n_steps: 1 log_peak_memory_stats: False diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 8adfa9b78f..14b20e33ee 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import sys import time @@ -13,6 +12,7 @@ from warnings import warn import torch +import torchtune.modules.common_utils as common_utils from omegaconf import DictConfig, ListConfig from torch import nn @@ -222,12 +222,14 @@ def setup(self, cfg: DictConfig) -> None: # log config with parameter override self._metric_logger.log_config(cfg) - self._model_compile = cfg.compile + self._compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) teacher_checkpoint_dict = self.load_teacher_checkpoint( cfg_checkpointer=cfg.teacher_checkpointer ) + common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False) + # set up model self._model = self._setup_model( cfg_model=cfg.model, @@ -261,26 +263,19 @@ def setup(self, cfg: DictConfig) -> None: # initialize loss self._loss_fn = config.instantiate(cfg.loss) self._kd_loss_fn = config.instantiate(cfg.kd_loss) - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._compile: + self._loss_fn = training.compile_loss(self._loss_fn) + # TODO: compile kd_loss_fn + self._kd_loss_fn = training.compile_loss(self._kd_loss_fn) if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) + # assert _loss_fn and _kd_loss_fn have the same num_output_chunks assert ( self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks ), "Number of output chunks for loss_fn and kd_loss_fn must be the same." - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - # For CEWithChunkedOutputLoss, if we compile the entire class - # we lose the benefits from the chunked loss. - # Therefore, we only compile the cross entropy function + upcasting - self._loss_fn.compute_cross_entropy = torch.compile( - self._loss_fn.compute_cross_entropy, backend=backend - ) - else: - if self._model_compile: - log.info("Compiling loss with torch.compile...") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") # Dataloader depends on the tokenizer and loss_fn and should be @@ -413,11 +408,7 @@ def _setup_model( set_trainable_params(model, self.adapter_params) if compile_model: - log.info("Compiling model layers with torch.compile...") - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - for m in reversed(list(model.modules())): - if isinstance(m, modules.transformer.TransformerSelfAttentionLayer): - m.compile(backend=backend) + training.compile_model(model) if enable_activation_checkpointing: utils.set_activation_checkpointing( @@ -664,14 +655,13 @@ def train(self) -> None: The core training loop. """ - if self._model_compile: + if self._compile: log.info( "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." ) # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() - running_loss = 0 running_class_loss = 0 running_kd_loss = 0 num_tokens = 0 @@ -706,7 +696,6 @@ def train(self) -> None: class_loss, kd_loss = self._loss_step(batch) loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss loss = loss / self._gradient_accumulation_steps - running_loss += loss running_class_loss += class_loss / self._gradient_accumulation_steps running_kd_loss += kd_loss / self._gradient_accumulation_steps loss.backward() @@ -724,7 +713,11 @@ def train(self) -> None: # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() + class_loss_to_log = running_class_loss.item() + kd_loss_to_log = running_kd_loss.item() + loss_to_log = ( + 1 - self._kd_ratio + ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log pbar.update(1) pbar.set_description( f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" @@ -735,8 +728,8 @@ def train(self) -> None: time_per_step = time.perf_counter() - t0 log_dict = { "loss": loss_to_log, - "class_loss": running_class_loss.item(), - "kd_loss": running_kd_loss.item(), + "class_loss": class_loss_to_log, + "kd_loss": kd_loss_to_log, "lr": self._optimizer.param_groups[0]["lr"], "tokens_per_second_per_gpu": num_tokens / time_per_step, } @@ -755,7 +748,6 @@ def train(self) -> None: ) # Reset running stats for the next step - running_loss = 0 running_class_loss = 0 running_kd_loss = 0 num_tokens = 0 diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 825e5c7719..13f98c22fe 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -281,10 +281,6 @@ class Recipe: name="kd_single_device", file_path="kd_single_device.py", configs=[ - Config( - name="llama3_1/kd_single_device", - file_path="llama3_1/kd_single_device.yaml", - ), Config( name="qwen2/kd_single_device", file_path="qwen2/kd_single_device.yaml", From 106aa3eb23d680197d013940c5ef498d2e785211 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Mon, 16 Sep 2024 13:59:03 -0700 Subject: [PATCH 21/37] distributed recipe --- recipes/kd_distributed.py | 915 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 915 insertions(+) create mode 100644 recipes/kd_distributed.py diff --git a/recipes/kd_distributed.py b/recipes/kd_distributed.py new file mode 100644 index 0000000000..f5e966d267 --- /dev/null +++ b/recipes/kd_distributed.py @@ -0,0 +1,915 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple, Union +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import destroy_process_group, init_process_group + +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler +from torchtune import config, modules, training, utils +from torchtune.data import padded_collate_packed, padded_collate_sft +from torchtune.datasets import ConcatDataset +from torchtune.modules.peft import ( + DoRALinear, + get_lora_module_names, + get_merged_lora_ckpt, + load_dora_magnitudes, + LoRALinear, + set_trainable_params, + validate_missing_and_unexpected_for_lora, +) +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.training import DummyProfiler, PROFILER_KEY + +from tqdm import tqdm + +log = utils.get_logger("DEBUG") + + +class KDRecipeDistributed(FTRecipeInterface): + """ + Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized + for single GPU training. Training on CPU is not supported. + + Features: + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported.g + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * gradient accumulation steps. + + For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes + library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with + 8-bit AdamW and Paged AdamW. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Currently we checkpoint both the adapter weights (trainable params only) and the + complete merged weights (adapter weights added back to the base model). For more details + please take a look at our LoRA tutorial + (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). + + Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. Resuming + training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + + """ + + def __init__(self, cfg: DictConfig) -> None: + self._device = utils.get_device(device=cfg.device) + # Reduced precision logic + self._dtype = training.get_dtype(cfg.dtype, device=self._device) + # fp16 precision is explicitly disabled as it is not supported in this + # recipe (for example, no gradient scaling). + if self._dtype == torch.float16: + raise ValueError( + "fp16 precision is not supported in this recipe. Please use fp32 or bf16." + ) + + _, rank = training.get_world_size_and_rank() + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + self._is_rank_zero = rank == 0 + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # training attributes + self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = training.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + self._kd_ratio = cfg.get("kd_ratio", 0.5) + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. This includes the + base model weights. If resume_from_checkpoint is True, this also includes + the adapter weights and recipe state + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + if training.ADAPTER_KEY not in checkpoint_dict: + raise ValueError( + "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." + ) + # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded + # no need to check here + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the teacher checkpoint state from file. + """ + teacher_checkpointer = config.instantiate( + cfg_checkpointer, + ) + checkpoint_dict = teacher_checkpointer.load_checkpoint() + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[training.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[training.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[training.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), + model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. + """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + self._compile = cfg.get("compile", False) + checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + teacher_checkpoint_dict = self.load_teacher_checkpoint( + cfg_checkpointer=cfg.teacher_checkpointer + ) + + # set up model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), + base_model_state_dict=checkpoint_dict[training.MODEL_KEY], + lora_weights_state_dict=( + checkpoint_dict[training.ADAPTER_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + self._teacher_model = self._setup_teacher_model( + model_cfg=cfg.teacher_model, + model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY], + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + log.info("Tokenizer is initialized from file.") + + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + checkpoint_dict[training.OPT_KEY] + if self._resume_from_checkpoint + else None + ), + ) + + # initialize loss + self._loss_fn = config.instantiate(cfg.loss) + self._kd_loss_fn = config.instantiate(cfg.kd_loss) + if self._compile: + self._loss_fn = training.compile_loss( + self._loss_fn, verbose=self._is_rank_zero + ) + self._kd_loss_fn = training.compile_loss( + self._kd_loss_fn, verbose=self._is_rank_zero + ) + + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) + # assert _loss_fn and _kd_loss_fn have the same num_output_chunks + assert ( + self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks + ), "Number of output chunks for loss_fn and kd_loss_fn must be the same." + + if self._is_rank_zero: + log.info("Loss is initialized.") + + # Dataloader depends on the tokenizer and loss_fn and should be + # setup after all of these are setup + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader and the max_steps_per_epoch param set by the user and is used + # for logging and tracking training state. This should be computed after the dataloader + # has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + # Learning rate scheduler can only be set up after number of steps + # has been computed + self._lr_scheduler = self._setup_lr_scheduler( + cfg_lr_scheduler=cfg.lr_scheduler, + num_training_steps=self.total_epochs * self._steps_per_epoch, + last_epoch=self.global_step - 1, + ) + + # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) + # if cfg is missing profiler key or if `cfg.profiler.enabled = False + self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + + def _setup_profiler( + self, cfg_profiler: Optional[DictConfig] = None + ) -> Union[torch.profiler.profile, DummyProfiler]: + """ + Parses the `profiler` section of top-level `cfg` and sets up profiler + + Args: + cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to + `recipe.main`). Default None. + + Returns: + profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods + for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such + that the instrumented training loop does not need to be changed profiling is disabled. + + The profiler config can be provided in configs under the `profiler` key with the following layout: + + .. code-block:: yaml + profiler: + enabled: bool + + #Output directory of trace artifacts + output_dir: str + + #`torch.profiler.ProfilerActivity` types to trace + cpu: bool + cuda: bool + + #Trace options + profile_memory: bool + with_stack: bool + record_shapes: bool + with_flops: bool + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: int + warmup_steps: int + active_steps: int + num_cycles: int + """ + + # Missing profiler section in config, assume disabled + if cfg_profiler is None: + cfg_profiler = DictConfig({"enabled": False}) + + # Check that component is included and set correctly + if cfg_profiler.get("_component_", None) is None: + cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" + else: + assert ( + cfg_profiler.get("_component_") + == "torchtune.training.setup_torch_profiler" + ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" + + profiler, profiler_cfg = config.instantiate(cfg_profiler) + + if self._is_rank_zero: + log.info(f" Profiler config after instantiation: {profiler_cfg}") + + self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) + if profiler_cfg["enabled"]: + self.profiler_wait_steps = profiler_cfg["wait_steps"] + self.profiler_warmup_steps = profiler_cfg["warmup_steps"] + self.profiler_active_steps = profiler_cfg["active_steps"] + + return profiler + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, + base_model_state_dict: Dict[str, Any], + lora_weights_state_dict: Optional[Dict[str, Any]] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + c. We register (pre-)forward hooks with ``fully_shard`` instead of wrapping `nn.Module` + """ + + self._lora_rank = cfg_model.lora_rank + self._lora_alpha = cfg_model.lora_alpha + self._lora_attn_modules = list(cfg_model.lora_attn_modules) + self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp + self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) + + if self._is_rank_zero: + log.info( + "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + self.adapter_params = training.get_adapter_params(model) + set_trainable_params(model, self.adapter_params) + + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + + if enable_activation_checkpointing: + training.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} + ) + + # For FSDP sharding, we can condition on either the module or its name + # Shard conditions should be callables taking name (relative to model root) + # and the module itself and returning a bool on whether to shard the given module + + # Shard transformer decoder layers (or AC-wrapped versions) + # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) + # But directly using the name is more concise + def _is_layer_name(name: str, module: nn.Module) -> bool: + """ + Return True for layers.i and False for all other module names + Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot + """ + name_list = name.split(".") + return ( + len(name_list) == 2 + and name_list[0] == "layers" + and str.isdigit(name_list[1]) + ) + + training.shard_model( + model=model, + shard_conditions=[_is_layer_name], + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + if lora_weights_state_dict: + lora_missing, lora_unexpected = training.load_from_full_model_state_dict( + model, + lora_weights_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + else: + lora_missing, lora_unexpected = None, None + + # Initializer for LoRA params and RoPE buffers + with training.set_default_dtype(self._dtype), self._device: + lora_device = "cpu" if fsdp_cpu_offload else self._device + for m in model.modules(): + if ( + isinstance(m, LoRALinear) or isinstance(m, DoRALinear) + ) and not lora_weights_state_dict: + # lora may not be covered in state dict + # if finetune for the 1st time + m.lora_a.to_empty(device=lora_device) + m.lora_b.to_empty(device=lora_device) + m.initialize_parameters() + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + base_missing, base_unexpected = training.load_from_full_model_state_dict( + model, + base_model_state_dict, + self._device, + self._is_rank_zero, + cpu_offload=fsdp_cpu_offload, + ) + is_dora = False + for m in model.modules(): + if hasattr(m, "initialize_dora_magnitude"): + is_dora = (True,) + m.initialize_dora_magnitude() + if is_dora: + load_dora_magnitudes(model) + validate_missing_and_unexpected_for_lora( + lora_attn_modules=self._lora_attn_modules, + apply_lora_to_mlp=self._apply_lora_to_mlp, + apply_lora_to_output=self._apply_lora_to_output, + base_missing=base_missing, + base_unexpected=base_unexpected, + lora_missing=lora_missing, + lora_unexpected=lora_unexpected, + ) + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + + if self._is_rank_zero: + log.info( + f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_teacher_model( + self, + model_cfg: DictConfig, + model_state_dict: Dict[str, Any], + ) -> nn.Module: + with training.set_default_dtype(self._dtype), self._device: + model = config.instantiate(model_cfg) + + model.load_state_dict(model_state_dict) + + # Put model in eval mode. + # Note: This will not disable the dropout applied in SDPA, + # see https://github.com/pytorch/pytorch/issues/124464 + model.eval() + + # Validate model was loaded in with the expected dtype. + training.validate_expected_param_dtype( + model.named_parameters(), dtype=self._dtype + ) + if self._is_rank_zero: + log.info(f"Teacher model is initialized with precision {self._dtype}.") + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + if opt_state_dict: + training.load_from_full_optimizer_state_dict( + optimizer, + opt_state_dict, + self._device, + ) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_lr_scheduler( + self, + cfg_lr_scheduler: DictConfig, + num_training_steps: int, + last_epoch: int, + ) -> Optimizer: + lr_scheduler = config.instantiate( + cfg_lr_scheduler, + self._optimizer, + num_training_steps=num_training_steps, + last_epoch=last_epoch, + ) + + if self._is_rank_zero: + log.info("Learning rate scheduler is initialized.") + return lr_scheduler + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports + Map-style Datasets which fit into memory and an option for random shuffling. + Samplers, iterable datasets, and streaming datasets are not supported. + """ + world_size, rank = training.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Checkpoint the state of the recipe. The constructed checkpoint state dict + contains the following information: + - Merged weights with key MODEL_KEY + - Adapter weights with key ADAPTER_KEY + - Relevant recipe state if training is not complete + - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights + + To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. + """ + # final dict passed onto the checkpointer + checkpoint_dict = {} + + intermediate_checkpoint = epoch + 1 < self.total_epochs + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + cpu_state_dict = training.get_full_model_state_dict( + self._model, + self._is_rank_zero, + device=self._device, + ) + + if intermediate_checkpoint: + opt_state_dict = training.get_full_optimizer_state_dict( + self._optimizer, + self._is_rank_zero, + device=self._device, + ) + else: + opt_state_dict = None + + # Now that we have the model and opt state dict, create the actual checkpoint + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + # Filter out the adapter keys and weights from the model state dict. These will + # be saved separately + adapter_key_filter = lambda x: x in self.adapter_params + adapter_state_dict = { + k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k) + } + checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict}) + + # merge the adapter weights and base weights to create the model checkpoint + merged_state_dict = get_merged_lora_ckpt( + cpu_state_dict, + rank=self._lora_rank, + alpha=self._lora_alpha, + ) + checkpoint_dict.update({training.MODEL_KEY: merged_state_dict}) + + # if training is in-progress, checkpoint the optimizer state and recipe state + # as well + if intermediate_checkpoint: + checkpoint_dict.update( + { + training.OPT_KEY: opt_state_dict, + training.SEED_KEY: self.seed, + training.EPOCHS_KEY: self.epochs_run, + training.TOTAL_EPOCHS_KEY: self.total_epochs, + training.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + adapter_config = { + "r": self._lora_rank, + "lora_alpha": self._lora_alpha, + "target_modules": get_lora_module_names( + self._lora_attn_modules, + self._apply_lora_to_mlp, + self._apply_lora_to_output, + ), + "peft_type": "LORA", + } + checkpoint_dict.update({training.ADAPTER_CONFIG: adapter_config}) + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=intermediate_checkpoint, + adapter_only=self._save_adapter_weights_only, + ) + + def _loss_step( + self, batch: Dict[str, torch.Tensor] + ) -> (torch.Tensor, torch.Tensor): + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + # run model + logits = self._model(tokens, mask=mask, input_pos=input_pos) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + + # Compute KD loss + with torch.no_grad(): + teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) + + # Compute kd loss + kd_loss = self._kd_loss_fn(logits, teacher_logits, labels) + + # Compute loss + loss = self._loss_fn(logits, labels) + + # free logits otherwise it peaks backward memory + del logits + del teacher_logits + + return loss, kd_loss + + def train(self) -> None: + """ + The core training loop. + """ + # clean up before training begins + training.cleanup_before_training() + + _, rank = training.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_class_loss = 0 + running_kd_loss = 0 + num_tokens = 0 + + self._profiler.start() + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Start tracking CUDA memory for active steps for just the first epoch + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx == self.profiler_wait_steps + self.profiler_warmup_steps + ): + torch.cuda.memory._record_memory_history() + + batch = {k: v.to(self._device) for k, v in batch.items()} + num_tokens += batch["tokens"].numel() + + class_loss, kd_loss = self._loss_step(batch) + loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss + loss = loss / self._gradient_accumulation_steps + running_class_loss += class_loss / self._gradient_accumulation_steps + running_kd_loss += kd_loss / self._gradient_accumulation_steps + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + self._lr_scheduler.step() + # Update the number of steps when the weights are updated + self.global_step += 1 + + class_loss_to_log = running_class_loss.item() + kd_loss_to_log = running_kd_loss.item() + loss_to_log = ( + 1 - self._kd_ratio + ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log + pbar.update(1) + pbar.set_description( + f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "class_loss": class_loss_to_log, + "kd_loss": kd_loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update( + training.get_memory_stats(device=self._device) + ) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_class_loss = 0 + running_kd_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + # Stop tracking CUDA memory now that active steps are complete + if ( + self._is_rank_zero + and curr_epoch == 0 + and self.profiler_profile_memory + and idx + == self.profiler_wait_steps + + self.profiler_warmup_steps + + self.profiler_active_steps + ): + torch.cuda.memory._record_memory_history(enabled=None) + + # Step the profiler + # Note we are stepping each batch, which might not include optimizer step in the trace + # if the schedule cycle doesn't align with gradient accumulation. + self._profiler.step() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not training.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." + "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + training.set_torch_num_threads() + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + + config.log_config(recipe_name="KDRecipeDistributed", cfg=cfg) + + recipe = KDRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) From 0f4e92270bcff4e40af886809b46ae3f813afae8 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Mon, 16 Sep 2024 14:08:48 -0700 Subject: [PATCH 22/37] remove todo comment and test activation checkpointing --- recipes/kd_single_device.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py index 14b20e33ee..3f0a744861 100644 --- a/recipes/kd_single_device.py +++ b/recipes/kd_single_device.py @@ -265,7 +265,6 @@ def setup(self, cfg: DictConfig) -> None: self._kd_loss_fn = config.instantiate(cfg.kd_loss) if self._compile: self._loss_fn = training.compile_loss(self._loss_fn) - # TODO: compile kd_loss_fn self._kd_loss_fn = training.compile_loss(self._kd_loss_fn) if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": # set num_output_chunks for model @@ -411,7 +410,7 @@ def _setup_model( training.compile_model(model) if enable_activation_checkpointing: - utils.set_activation_checkpointing( + training.set_activation_checkpointing( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) From 0bb49dc02fcb75b09136a80dcdf8472b0a97f4ab Mon Sep 17 00:00:00 2001 From: lindawangg Date: Mon, 16 Sep 2024 20:00:38 -0700 Subject: [PATCH 23/37] qwen2 distributed recipe --- recipes/configs/qwen2/kd_distributed.yaml | 123 ++++++++++++++++++++++ 1 file changed, 123 insertions(+) create mode 100644 recipes/configs/qwen2/kd_distributed.yaml diff --git a/recipes/configs/qwen2/kd_distributed.yaml b/recipes/configs/qwen2/kd_distributed.yaml new file mode 100644 index 0000000000..c3b225fe63 --- /dev/null +++ b/recipes/configs/qwen2/kd_distributed.yaml @@ -0,0 +1,123 @@ +# Config for multi-device knowledge distillation in kd_distributed.py +# using a teacher and student model +# +# This config assumes that you've ran the following commands before launching KD: +# First download the student and teacher models +# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None +# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None +# +# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetuned_distributed --config qwen2/1.5B_lora +# +# To launch on a single device, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 kd_distributed --config qwen2/kd_distributed +# +# This config works only for distilling on a single device. + + +# Model Arguments +model: + _component_: torchtune.models.qwen2.lora_qwen2_0_5b + lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] + apply_lora_to_mlp: False + lora_rank: 32 + lora_alpha: 64 + +teacher_model: + _component_: torchtune.models.qwen2.qwen2_1_5b + +tokenizer: + _component_: torchtune.models.qwen2.qwen2_tokenizer + path: /tmp/Qwen2-0.5B-Instruct/vocab.json + merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-0.5B-Instruct + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-0.5B-Instruct-kd + model_type: QWEN2 + +teacher_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + checkpoint_files: [ + hf_model_0001_0.pt + ] + recipe_checkpoint: null + output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune + model_type: QWEN2 + +resume_from_checkpoint: False + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 8 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.modules.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +kd_loss: + _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss +kd_ratio: 0.5 + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 2 + +# Logging +output_dir: /tmp/qwen_kd +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: True + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 From c73857d254569515a1ad1e616503c51262d89166 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Mon, 16 Sep 2024 20:06:18 -0700 Subject: [PATCH 24/37] added to recipe registry --- recipes/kd_distributed.py | 3 ++- torchtune/_recipe_registry.py | 11 +++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/recipes/kd_distributed.py b/recipes/kd_distributed.py index f5e966d267..92754642d0 100644 --- a/recipes/kd_distributed.py +++ b/recipes/kd_distributed.py @@ -24,6 +24,7 @@ from torchtune.datasets import ConcatDataset from torchtune.modules.peft import ( DoRALinear, + get_adapter_params, get_lora_module_names, get_merged_lora_ckpt, load_dora_magnitudes, @@ -426,7 +427,7 @@ def _setup_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(cfg_model) - self.adapter_params = training.get_adapter_params(model) + self.adapter_params = get_adapter_params(model) set_trainable_params(model, self.adapter_params) if self._compile: diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index 13f98c22fe..67ac9084de 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -288,6 +288,17 @@ class Recipe: ], supports_distributed=False, ), + Recipe( + name="kd_distributed", + file_path="kd_distributed.py", + configs=[ + Config( + name="qwen2/kd_distributed", + file_path="qwen2/kd_distributed.yaml", + ), + ], + supports_distributed=True, + ), ] From 59eff44513c6b66db58c0ec436f4e019f7739bd3 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 17 Sep 2024 16:00:21 -0700 Subject: [PATCH 25/37] fdsp teacher model --- recipes/kd_distributed.py | 82 +++++++++++++++++++++++++++++++++++---- 1 file changed, 75 insertions(+), 7 deletions(-) diff --git a/recipes/kd_distributed.py b/recipes/kd_distributed.py index 92754642d0..cd8291cae4 100644 --- a/recipes/kd_distributed.py +++ b/recipes/kd_distributed.py @@ -246,6 +246,8 @@ def setup(self, cfg: DictConfig) -> None: self._teacher_model = self._setup_teacher_model( model_cfg=cfg.teacher_model, + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY], ) @@ -532,24 +534,90 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: def _setup_teacher_model( self, model_cfg: DictConfig, + fsdp_cpu_offload: bool, + reshard_after_forward: bool, model_state_dict: Dict[str, Any], ) -> nn.Module: - with training.set_default_dtype(self._dtype), self._device: + """ + Model initialization for teacher model has some important considerations: + a. To minimize GPU peak memory, we initialize the model on meta device with + the right dtype + b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since + full state dicts are loaded with ``torch.load(mmap=True)`` + """ + + if self._is_rank_zero: + log.info( + "FSDP enabled. Instantiating teacher model and loading checkpoint on Rank 0 ..." + ) + init_start = time.perf_counter() + + with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(model_cfg) - model.load_state_dict(model_state_dict) + # For FSDP sharding, we can condition on either the module or its name + # Shard conditions should be callables taking name (relative to model root) + # and the module itself and returning a bool on whether to shard the given module + fsdp_shard_conditions = [] + + # Shard transformer decoder layers (or AC-wrapped versions) + # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) + # But directly using the name is more concise + def _is_layer_fqn(s: str) -> bool: + """ + Return True for layers.i and False for all other module names + Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot + """ + s_list = s.split(".") + return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) + + fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] + + training.shard_model( + model=model, + shard_conditions=fsdp_shard_conditions, + cpu_offload=fsdp_cpu_offload, + reshard_after_forward=reshard_after_forward, + ) + + with training.set_default_dtype(self._dtype), self._device: + for m in model.modules(): + # RoPE is not covered in state dict + if hasattr(m, "rope_init"): + m.rope_init() + + # This method will convert the full model state dict into a sharded state + # dict and load into the model + training.load_from_full_model_state_dict( + model, + model_state_dict, + self._device, + self._is_rank_zero, + strict=True, + cpu_offload=fsdp_cpu_offload, + ) # Put model in eval mode. # Note: This will not disable the dropout applied in SDPA, # see https://github.com/pytorch/pytorch/issues/124464 model.eval() - # Validate model was loaded in with the expected dtype. - training.validate_expected_param_dtype( - model.named_parameters(), dtype=self._dtype - ) + for p in model.parameters(): + p.requires_grad = False + + # Ensure no params and buffers are on meta device + training.validate_no_params_on_meta_device(model) + if self._is_rank_zero: - log.info(f"Teacher model is initialized with precision {self._dtype}.") + log.info( + f"Instantiating teacher model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" + ) + memory_stats = training.get_memory_stats(device=self._device) + training.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + return model def _setup_optimizer( From 85d76bba6df621c203498e76c35d13a7926116fe Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 17 Sep 2024 22:45:59 -0700 Subject: [PATCH 26/37] added kd distributed test --- tests/recipes/test_kd_distributed.py | 290 +++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 tests/recipes/test_kd_distributed.py diff --git a/tests/recipes/test_kd_distributed.py b/tests/recipes/test_kd_distributed.py new file mode 100644 index 0000000000..5c05d30bda --- /dev/null +++ b/tests/recipes/test_kd_distributed.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import runpy +import sys +from pathlib import Path + +import pytest +import torch +from omegaconf import OmegaConf +from tests.common import TUNE_PATH +from tests.recipes.utils import ( + CKPT_COMPONENT_MAP, + dummy_alpaca_dataset_config, + MODEL_TEST_CONFIGS, + write_hf_ckpt_config, +) +from tests.test_utils import ( + CKPT_MODEL_PATHS, + gen_log_file_name, + get_loss_values_from_metric_logger, + gpu_test, + TOKENIZER_PATHS, +) +from torchtune import config + + +class TestKDDistributedDeviceRecipe: + def _get_test_config_overrides(self, epochs: int = 2): + return [ + "batch_size=4", + "enable_activation_checkpointing=False", + "dataset.train_on_input=False", + "seed=9", + f"epochs={epochs}", + "dtype=fp32", + "max_steps_per_epoch=2", + "optimizer.lr=2e-5", + "log_every_n_steps=1", + "gradient_accumulation_steps=1", + "compile=False", + ] + dummy_alpaca_dataset_config() + + def _fetch_expected_loss_values(self, model_type): + loss_values_map = { + "llama3": [10.3821, 10.3025, 11.0394, 11.7664], + } + return loss_values_map[model_type] + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + @pytest.mark.parametrize( + "reshard_after_forward", + [ + True, + False, + ], + ) + def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ + --config qwen2/kd_distributed \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type=LLAMA3 \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] + ] + + cmd = cmd + self._get_test_config_overrides() + model_config + teacher_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses] + expected_loss_values = self._fetch_expected_loss_values("llama3") + print(loss_values) + print(expected_loss_values) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_training_state_on_resume(self, tmpdir, monkeypatch): + """Test whether the recipe state is correctly updated on resume. Since this + is model agnostic, we should run this on the small model only. The test + consists of three stages: + - Train a model for 2 epochs + - Resume training after epoch 1 + - Make sure final loss matches the expected value of a model successfully resumed from a ckpt + """ + + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + # Config file needed for model conversion. + # Create a second copy for training resume + write_hf_ckpt_config(ckpt_dir) + write_hf_ckpt_config(tmpdir) + + # Train for two epochs + cmd_1 = f""" + tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ + --config qwen2/kd_distributed \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type=LLAMA3 \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ + tokenizer.path={tokenizer_path} \ + tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] + ] + + cmd_1 = ( + cmd_1 + self._get_test_config_overrides() + model_config + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd_1) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Resume training + cmd_2 = f""" + tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ + --config qwen2/kd_distributed \ + output_dir={tmpdir} \ + checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir={tmpdir} \ + checkpointer.checkpoint_files=[{ckpt_path}]\ + checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} + checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type=LLAMA3 \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type=LLAMA3 \ + resume_from_checkpoint=True \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ + tokenizer.path={tokenizer_path} \ + tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + cmd_2 = ( + cmd_2 + + self._get_test_config_overrides(epochs=3) + + model_config + + teacher_config + ) + monkeypatch.setattr(sys, "argv", cmd_2) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Second epoch only + expected_loss_values = self._fetch_expected_loss_values("llama3")[2:] + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses][:2] + + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + + @pytest.mark.integration_test + def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): + ckpt_type = "tune" + model_type = "llama3" + ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] + ckpt = model_type + "_" + ckpt_type + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + tokenizer_path = Path(TOKENIZER_PATHS[model_type]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ + --config qwen2/kd_distributed \ + output_dir={tmpdir} \ + checkpointer._component_={ckpt_component} \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + checkpointer.model_type={model_type.upper()} \ + teacher_checkpointer._component_={ckpt_component} \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + teacher_checkpointer.model_type={model_type.upper()} \ + tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + ~tokenizer.merges_file \ + metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ + metric_logger.filename={log_file} \ + kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ + kd_ratio=0.5 \ + """.split() + + model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS[model_type] + ] + + cmd = cmd + self._get_test_config_overrides() + model_config + teacher_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + # Next load both the merged weights in a Llama3 base model + # and the base model weights + trained adapter weights in the LoRA Llama 3 model + # The results of calling forward on dummy inputs should be the same. + inputs = torch.randint(low=0, high=32_000, size=(2, 100)) + + # Build LoRA model for loading base + adapter weights separately + lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) + + # Build base llama3 model for loading merged weights + base_llama3_config = MODEL_TEST_CONFIGS[model_type] + llama3_model = config.instantiate( + OmegaConf.from_dotlist(base_llama3_config).model + ) + + # Load base model and trained adapter weights into LoRA model and call fwd + with open(f"{tmpdir}/adapter_1.pt", "rb") as f: + lora_sd = torch.load(f, weights_only=True) + with open(ckpt_path, "rb") as f: + base_model_sd = torch.load(f, weights_only=True) + lora_model.load_state_dict(lora_sd, strict=False) + lora_model.load_state_dict(base_model_sd, strict=False) + baseline_out = lora_model(inputs) + + # Load merged final ckpt directly into 3 and call fwd + with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: + sd = torch.load(f, weights_only=True) + llama3_model.load_state_dict(sd) + merged_ckpt_out = llama3_model(inputs) + torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) From dba57c476196e77268a051d438e40e386ae7fbb3 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 18 Sep 2024 14:46:42 -0700 Subject: [PATCH 27/37] fixed command --- recipes/configs/qwen2/kd_distributed.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes/configs/qwen2/kd_distributed.yaml b/recipes/configs/qwen2/kd_distributed.yaml index c3b225fe63..336ac7b4ff 100644 --- a/recipes/configs/qwen2/kd_distributed.yaml +++ b/recipes/configs/qwen2/kd_distributed.yaml @@ -7,7 +7,7 @@ # tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None # # You get better results using KD if the teacher model has already been fine-tuned on the target dataset: -# tune run --nnodes 1 --nproc_per_node 2 lora_finetuned_distributed --config qwen2/1.5B_lora +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora # # To launch on a single device, run the following command from root: # tune run --nnodes 1 --nproc_per_node 2 kd_distributed --config qwen2/kd_distributed From 44123b91f4d417674eca2866ba260f074c298d34 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 19 Sep 2024 17:21:49 -0700 Subject: [PATCH 28/37] changed to knowledge_distillation --- recipes/configs/qwen2/kd_single_device.yaml | 97 --- ...> knowledge_distillation_distributed.yaml} | 4 +- recipes/kd_single_device.py | 795 ------------------ ... => knowledge_distillation_distributed.py} | 0 tests/recipes/test_kd_single_device.py | 308 ------- ...est_knowledge_distillation_distributed.py} | 16 +- torchtune/_recipe_registry.py | 8 +- 7 files changed, 14 insertions(+), 1214 deletions(-) delete mode 100644 recipes/configs/qwen2/kd_single_device.yaml rename recipes/configs/qwen2/{kd_distributed.yaml => knowledge_distillation_distributed.yaml} (93%) delete mode 100644 recipes/kd_single_device.py rename recipes/{kd_distributed.py => knowledge_distillation_distributed.py} (100%) delete mode 100644 tests/recipes/test_kd_single_device.py rename tests/recipes/{test_kd_distributed.py => test_knowledge_distillation_distributed.py} (95%) diff --git a/recipes/configs/qwen2/kd_single_device.yaml b/recipes/configs/qwen2/kd_single_device.yaml deleted file mode 100644 index a596ba7aad..0000000000 --- a/recipes/configs/qwen2/kd_single_device.yaml +++ /dev/null @@ -1,97 +0,0 @@ -# Config for single device knowledge distillation in kd_single_device.py -# using a teacher and student model -# -# This config assumes that you've ran the following commands before launching KD: -# First download the student and teacher models -# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct --ignore-patterns None -# tune download Qwen/Qwen2-1.5B-Instruct --output-dir /tmp/Qwen2-1.5B-Instruct --ignore-patterns None -# -# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: -# tune run lora_finetune_single_device --config qwen2/1.5B_lora_single_device -# -# To launch on a single device, run the following command from root: -# tune run kd_single_device --config qwen2/kd_single_device -# -# This config works only for distilling on a single device. - - -# Model Arguments -model: - _component_: torchtune.models.qwen2.lora_qwen2_0_5b - lora_attn_modules: ['q_proj', 'k_proj', 'v_proj'] - apply_lora_to_mlp: False - lora_rank: 32 - lora_alpha: 64 - -teacher_model: - _component_: torchtune.models.qwen2.qwen2_1_5b - -tokenizer: - _component_: torchtune.models.qwen2.qwen2_tokenizer - path: /tmp/Qwen2-0.5B-Instruct/vocab.json - merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt - max_seq_len: null - -checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Qwen2-0.5B-Instruct - checkpoint_files: [ - model.safetensors - ] - recipe_checkpoint: null - output_dir: /tmp/Qwen2-0.5B-Instruct-kd - model_type: QWEN2 - -teacher_checkpointer: - _component_: torchtune.training.FullModelHFCheckpointer - checkpoint_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune - checkpoint_files: [ - hf_model_0001_0.pt - ] - recipe_checkpoint: null - output_dir: /tmp/Qwen2-1.5B-Instruct-lora-finetune - model_type: QWEN2 - -resume_from_checkpoint: False - -# Dataset and Sampler -dataset: - _component_: torchtune.datasets.alpaca_cleaned_dataset -seed: null -shuffle: True -batch_size: 8 - -# Optimizer and Scheduler -optimizer: - _component_: torch.optim.AdamW - weight_decay: 0.01 - lr: 3e-4 -lr_scheduler: - _component_: torchtune.modules.get_cosine_schedule_with_warmup - num_warmup_steps: 100 - -loss: - _component_: torchtune.modules.loss.CEWithChunkedOutputLoss - -kd_loss: - _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss -kd_ratio: 0.5 - -# Training -epochs: 1 -max_steps_per_epoch: null -gradient_accumulation_steps: 2 -compile: False - -# Logging -output_dir: /tmp/qwen_kd -metric_logger: - _component_: torchtune.training.metric_logging.DiskLogger - log_dir: ${output_dir} -log_every_n_steps: 1 -log_peak_memory_stats: False - -# Environment -device: cuda -dtype: bf16 -enable_activation_checkpointing: False diff --git a/recipes/configs/qwen2/kd_distributed.yaml b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml similarity index 93% rename from recipes/configs/qwen2/kd_distributed.yaml rename to recipes/configs/qwen2/knowledge_distillation_distributed.yaml index 336ac7b4ff..33002af76b 100644 --- a/recipes/configs/qwen2/kd_distributed.yaml +++ b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml @@ -1,4 +1,4 @@ -# Config for multi-device knowledge distillation in kd_distributed.py +# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py # using a teacher and student model # # This config assumes that you've ran the following commands before launching KD: @@ -10,7 +10,7 @@ # tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora # # To launch on a single device, run the following command from root: -# tune run --nnodes 1 --nproc_per_node 2 kd_distributed --config qwen2/kd_distributed +# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed # # This config works only for distilling on a single device. diff --git a/recipes/kd_single_device.py b/recipes/kd_single_device.py deleted file mode 100644 index 3f0a744861..0000000000 --- a/recipes/kd_single_device.py +++ /dev/null @@ -1,795 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import sys -import time - -from functools import partial -from typing import Any, Dict, Optional, Tuple, Union -from warnings import warn - -import torch -import torchtune.modules.common_utils as common_utils -from omegaconf import DictConfig, ListConfig - -from torch import nn -from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler -from torchtune import config, modules, training, utils -from torchtune.data import padded_collate_sft -from torchtune.datasets import ConcatDataset -from torchtune.modules.peft import ( - get_adapter_params, - get_lora_module_names, - get_merged_lora_ckpt, - load_dora_magnitudes, - set_trainable_params, - validate_missing_and_unexpected_for_lora, -) -from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY - -from tqdm import tqdm - -log = utils.get_logger("DEBUG") - - -class KDRecipeSingleDevice(FTRecipeInterface): - """ - Knowledge distillation recipe for dense transformer-based LLMs such as Llama3. This recipe is optimized - for single GPU training. Training on CPU is not supported. - - Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` - flag. Activation checkpointing helps reduce the memory footprint since we no longer keep - activations in memory and instead recompute them during the backward pass. This is especially - helpful for larger batch sizes when you're memory constrained. But these savings in memory - come at the cost of training performance. In most cases training can slow-down quite a bit as - a result of this activation recomputation. - - - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` - flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In - most cases this should halve the memory footprint of full precision (fp32) training, without - loss in model quality (will depend on the model, training data and other settings). For - GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 - precision are currently not supported.g - - - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is - controlled using the ``gradient_accumulation_steps`` flag. - - Total Batch Size = batch_size * gradient accumulation steps. - - For example: with batch_size=1 and gradient_accumulation_steps=32 we get a total batch size of 32. - - Gradient accumulation is especially useful when you are memory constrained. In this case, - accumulating gradients might give you better training speed than enabling activation - checkpointing. - - - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes - library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with - 8-bit AdamW and Paged AdamW. - - - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of - training. Currently we checkpoint both the adapter weights (trainable params only) and the - complete merged weights (adapter weights added back to the base model). For more details - please take a look at our LoRA tutorial - (https://pytorch.org/torchtune/main/tutorials/lora_finetune.html). - - Optimizer State and recipe state (seed, total_epochs, number of epochs run etc) are - only saved at the end of a given epoch and used in case of resuming training. Resuming - training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is - currently not supported. - - For more details on the checkpointer, please take a look at - our checkpointer deepdive (https://pytorch.org/torchtune/main/tutorials/checkpointer.html). - - - Logging. Terminal, Disk, WandB and TensorBoard are all supported. - - - Gradient Clipping. Gradient clipping is supported using the ``clip_grad_norm`` flag. By default, - ``clip_grad_norm`` is set to ``None``. If you only want to log the grad norm, you can set - ``clip_grad_norm='inf'``. - - For a full list of example configs for this recipe, run ``tune ls`` on the command line. Each config - has example commands for how to kick-off training. - - Args: - cfg (DictConfig): OmegaConf object parsed from yaml file - - Raises: - ValueError: If ``dtype`` is set to fp16. - RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. - - """ - - def __init__(self, cfg: DictConfig) -> None: - self._device = utils.get_device(device=cfg.device) - # Reduced precision logic - self._dtype = training.get_dtype(cfg.dtype, device=self._device) - # fp16 precision is explicitly disabled as it is not supported in this - # recipe (for example, no gradient scaling). - if self._dtype == torch.float16: - raise ValueError( - "fp16 precision is not supported in this recipe. Please use fp32 or bf16." - ) - # For CUDA devices, check if the HW supports bf16 if bf16 is specified. - if ( - self._dtype == torch.bfloat16 - and self._device != torch.device("cpu") - and not torch.cuda.is_bf16_supported() - ): - raise RuntimeError("Full bf16 training is not supported on this hardware.") - # logging attributes - self._output_dir = cfg.output_dir - self._log_every_n_steps = cfg.get("log_every_n_steps", 1) - self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) - - # These are public properties which are updated by the checkpoint loader - # when ``resume_from_checkpoint`` is `True` or validated in tests - self.seed = training.set_seed(seed=cfg.seed) - self.epochs_run = 0 - self.total_epochs = cfg.epochs - self.max_steps_per_epoch = cfg.max_steps_per_epoch - self.global_step = 0 - self._resume_from_checkpoint = cfg.resume_from_checkpoint - self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False) - self._gradient_accumulation_steps = cfg.gradient_accumulation_steps - self._clip_grad_norm = cfg.get("clip_grad_norm", None) - self._kd_ratio = cfg.get("kd_ratio", 0.5) - - def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: - """ - Extract the checkpoint state from file and validate. This includes the - base model weights. If resume_from_checkpoint is True, this also includes - the adapter weights and recipe state - """ - self._checkpointer = config.instantiate( - cfg_checkpointer, - resume_from_checkpoint=self._resume_from_checkpoint, - ) - checkpoint_dict = self._checkpointer.load_checkpoint() - - if self._resume_from_checkpoint: - if training.ADAPTER_KEY not in checkpoint_dict: - raise ValueError( - "Adapter weights not found. Please ensure a valid adapter checkpoint is provided." - ) - # _update_recipe_state will throw an exception if the recipe state is not corrctly loaded - # no need to check here - self._update_recipe_state(checkpoint_dict) - return checkpoint_dict - - def load_teacher_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: - """ - Extract the teacher checkpoint state from file. - """ - teacher_checkpointer = config.instantiate( - cfg_checkpointer, - ) - checkpoint_dict = teacher_checkpointer.load_checkpoint() - return checkpoint_dict - - def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: - """ - Updates the recipe state from checkpoint. - """ - try: - self.epochs_run = ckpt_dict[training.EPOCHS_KEY] - - # on mismatch, warn the user and prevent the override - if self.seed != ckpt_dict[training.SEED_KEY]: - warn( - message=( - "Config value for seed does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.SEED_KEY]}" - ) - ) - self.seed = ckpt_dict[training.SEED_KEY] - if self.max_steps_per_epoch != ckpt_dict[training.MAX_STEPS_KEY]: - warn( - message=( - "Config value for max_steps_per_epoch does not match the checkpoint value, " - f"using the checkpoint value: {ckpt_dict[training.MAX_STEPS_KEY]}" - ) - ) - self.max_steps_per_epoch = ckpt_dict[training.MAX_STEPS_KEY] - - # on mismatch, warn the user but allow the override - if self.total_epochs != ckpt_dict[training.TOTAL_EPOCHS_KEY]: - warn( - message=( - "Config value for total_epochs does not match the checkpoint value, " - f"using the config value: {self.total_epochs}" - ) - ) - - except KeyError as e: - raise KeyError( - "Checkpoint does not contain the required keys needed for updating recipe state. " - "Are you sure you passed in the right recipe checkpoint?" - ) from e - - def setup(self, cfg: DictConfig) -> None: - """ - Setup the recipe state. This includes recipe state (if resume_from_checkpoint is True), - model, tokenizer, loss, optimizer, learning rate scheduler, sampler, and dataloader. - """ - - self._metric_logger = config.instantiate(cfg.metric_logger) - - # log config with parameter override - self._metric_logger.log_config(cfg) - - self._compile = cfg.compile - checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) - teacher_checkpoint_dict = self.load_teacher_checkpoint( - cfg_checkpointer=cfg.teacher_checkpointer - ) - - common_utils._use_low_cpu_ram = cfg.get("low_cpu_ram", False) - - # set up model - self._model = self._setup_model( - cfg_model=cfg.model, - enable_activation_checkpointing=cfg.enable_activation_checkpointing, - compile_model=cfg.compile, - base_model_state_dict=checkpoint_dict[training.MODEL_KEY], - lora_weights_state_dict=( - checkpoint_dict[training.ADAPTER_KEY] - if self._resume_from_checkpoint - else None - ), - ) - - self._teacher_model = self._setup_teacher_model( - model_cfg=cfg.teacher_model, - model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY], - ) - - self._tokenizer = config.instantiate(cfg.tokenizer) - log.info("Tokenizer is initialized from file.") - - self._optimizer = self._setup_optimizer( - cfg_optimizer=cfg.optimizer, - opt_state_dict=( - checkpoint_dict[training.OPT_KEY] - if self._resume_from_checkpoint - else None - ), - ) - - # initialize loss - self._loss_fn = config.instantiate(cfg.loss) - self._kd_loss_fn = config.instantiate(cfg.kd_loss) - if self._compile: - self._loss_fn = training.compile_loss(self._loss_fn) - self._kd_loss_fn = training.compile_loss(self._kd_loss_fn) - if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": - # set num_output_chunks for model - self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) - self._teacher_model.set_num_output_chunks(self._loss_fn.num_output_chunks) - # assert _loss_fn and _kd_loss_fn have the same num_output_chunks - assert ( - self._loss_fn.num_output_chunks == self._kd_loss_fn.num_output_chunks - ), "Number of output chunks for loss_fn and kd_loss_fn must be the same." - - log.info("Loss is initialized.") - - # Dataloader depends on the tokenizer and loss_fn and should be - # setup after all of these are setup - self._sampler, self._dataloader = self._setup_data( - cfg_dataset=cfg.dataset, - shuffle=cfg.shuffle, - batch_size=cfg.batch_size, - ) - - # Finally update the recipe state which can only be correctly set after all of the - # other components have been initialized and updated. - - # Number of training steps in each epoch depends on the number of batches produced - # by the dataloader and the max_steps_per_epoch param set by the user and is used - # for logging and tracking training state. This should be computed after the dataloader - # has been setup - self._steps_per_epoch = ( - len(self._dataloader) // self._gradient_accumulation_steps - ) - if ( - self.max_steps_per_epoch is not None - and self.max_steps_per_epoch < self._steps_per_epoch - ): - self._steps_per_epoch = self.max_steps_per_epoch - self.global_step = self.epochs_run * self._steps_per_epoch - - # Learning rate scheduler can only be set up after number of steps - # has been computed - self._lr_scheduler = self._setup_lr_scheduler( - cfg_lr_scheduler=cfg.lr_scheduler, - num_training_steps=self.total_epochs * self._steps_per_epoch, - last_epoch=self.global_step - 1, - ) - - # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) - # if cfg is missing profiler key or if `cfg.profiler.enabled = False - self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) - - # Used to ignore labels for loss computation - self.ignore_labels_cache = torch.full( - (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device - ) - - def _setup_profiler( - self, cfg_profiler: Optional[DictConfig] = None - ) -> Union[torch.profiler.profile, DummyProfiler]: - """ - Parses the `profiler` section of top-level `cfg` and sets up profiler - - Args: - cfg_profiler (Optional[DictConfig]): ``profiler`` section of the top-level ``cfg`` (the main config passed to - `recipe.main`). Default None. - - Returns: - profiler: Union[torch.profiler.profile, DummyProfiler] - DummyProfiler is a nullcontext with no-op methods - for `start`, `stop`, and `step` that can be used in place of `torch.profiler.profile` if profiler is not enabled such - that the instrumented training loop does not need to be changed profiling is disabled. - - The profiler config can be provided in configs under the `profiler` key with the following layout: - - .. code-block:: yaml - profiler: - enabled: bool - - #Output directory of trace artifacts - output_dir: str - - #`torch.profiler.ProfilerActivity` types to trace - cpu: bool - cuda: bool - - #Trace options - profile_memory: bool - with_stack: bool - record_shapes: bool - with_flops: bool - - # `torch.profiler.schedule` options: - # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat - wait_steps: int - warmup_steps: int - active_steps: int - num_cycles: int - """ - - # Missing profiler section in config, assume disabled - if cfg_profiler is None: - cfg_profiler = DictConfig({"enabled": False}) - - # Check that component is included and set correctly - if cfg_profiler.get("_component_", None) is None: - cfg_profiler["_component_"] = "torchtune.training.setup_torch_profiler" - else: - assert ( - cfg_profiler.get("_component_") - == "torchtune.training.setup_torch_profiler" - ), "Only torch profiler supported currently: component must be `torchtune.training.setup_torch_profiler`" - - profiler, profiler_cfg = config.instantiate(cfg_profiler) - - log.info(f" Profiler config after instantiation: {profiler_cfg}") - - self.profiler_profile_memory = profiler_cfg.get("profile_memory", False) - if profiler_cfg["enabled"]: - self.profiler_wait_steps = profiler_cfg["wait_steps"] - self.profiler_warmup_steps = profiler_cfg["warmup_steps"] - self.profiler_active_steps = profiler_cfg["active_steps"] - - return profiler - - def _setup_model( - self, - cfg_model: DictConfig, - enable_activation_checkpointing: bool, - compile_model: bool, - base_model_state_dict: Dict[str, Any], - lora_weights_state_dict: Optional[Dict[str, Any]] = None, - ) -> nn.Module: - with training.set_default_dtype(self._dtype), self._device: - model = config.instantiate(cfg_model) - - self._lora_rank = cfg_model.lora_rank - self._lora_alpha = cfg_model.lora_alpha - self._lora_attn_modules = list(cfg_model.lora_attn_modules) - self._apply_lora_to_mlp = cfg_model.apply_lora_to_mlp - self._apply_lora_to_output = getattr(cfg_model, "apply_lora_to_output", False) - self.adapter_params = get_adapter_params(model) - self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) - set_trainable_params(model, self.adapter_params) - - if compile_model: - training.compile_model(model) - - if enable_activation_checkpointing: - training.set_activation_checkpointing( - model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} - ) - - base_missing, base_unexpected = model.load_state_dict( - base_model_state_dict, strict=False - ) - # This is for any adapters that need to be initialized after base weights - # have been loaded (e.g. DoRA). - if self._is_dora: - load_dora_magnitudes(model) - if lora_weights_state_dict: - lora_missing, lora_unexpected = model.load_state_dict( - lora_weights_state_dict, strict=False - ) - else: - lora_missing, lora_unexpected = None, None - validate_missing_and_unexpected_for_lora( - lora_attn_modules=self._lora_attn_modules, - apply_lora_to_mlp=self._apply_lora_to_mlp, - apply_lora_to_output=self._apply_lora_to_output, - base_missing=base_missing, - base_unexpected=base_unexpected, - lora_missing=lora_missing, - lora_unexpected=lora_unexpected, - ) - # Validate model adapter params were loaded in with the expected dtype - # TODO (rohan-varma): Further validation to ensure the appropriate base params - # are NF4 vs bf16 based on the quantization config. - training.validate_expected_param_dtype( - self.adapter_params.items(), dtype=self._dtype - ) - - log.info(f"Model is initialized with precision {self._dtype}.") - - if self._device.type == "cuda": - memory_stats = training.get_memory_stats(device=self._device) - training.log_memory_stats(memory_stats) - return model - - def _setup_teacher_model( - self, - model_cfg: DictConfig, - model_state_dict: Dict[str, Any], - ) -> nn.Module: - with training.set_default_dtype(self._dtype), self._device: - model = config.instantiate(model_cfg) - - model.load_state_dict(model_state_dict) - - # Put model in eval mode. - # Note: This will not disable the dropout applied in SDPA, - # see https://github.com/pytorch/pytorch/issues/124464 - model.eval() - - # Validate model was loaded in with the expected dtype. - training.validate_expected_param_dtype( - model.named_parameters(), dtype=self._dtype - ) - log.info(f"Teacher model is initialized with precision {self._dtype}.") - return model - - def _setup_optimizer( - self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None - ) -> Optimizer: - optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) - if opt_state_dict: - optimizer.load_state_dict(opt_state_dict) - - log.info("Optimizer and loss are initialized.") - return optimizer - - def _setup_lr_scheduler( - self, - cfg_lr_scheduler: DictConfig, - num_training_steps: int, - last_epoch: int, - ) -> Optimizer: - lr_scheduler = config.instantiate( - cfg_lr_scheduler, - self._optimizer, - num_training_steps=num_training_steps, - last_epoch=last_epoch, - ) - - log.info("Learning rate scheduler is initialized.") - return lr_scheduler - - def _setup_data( - self, - cfg_dataset: DictConfig, - shuffle: bool, - batch_size: int, - ) -> Tuple[DistributedSampler, DataLoader]: - """ - All data related setup happens here. Currently this recipe only supports - Map-style Datasets which fit into memory and an option for random shuffling. - Samplers, iterable datasets, and streaming datasets are not supported. - """ - if isinstance(cfg_dataset, ListConfig): - datasets = [ - config.instantiate(single_cfg_dataset, self._tokenizer) - for single_cfg_dataset in cfg_dataset - ] - ds = ConcatDataset(datasets=datasets) - packed = False - else: - ds = config.instantiate(cfg_dataset, self._tokenizer) - packed = cfg_dataset.get("packed", False) - - sampler = DistributedSampler( - ds, - num_replicas=1, - rank=0, - shuffle=shuffle, - seed=0, - ) - dataloader = DataLoader( - dataset=ds, - sampler=sampler, - batch_size=batch_size, - collate_fn=( - partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else None - ), - ) - - log.info("Dataset and Sampler are initialized.") - - return sampler, dataloader - - def save_checkpoint(self, epoch: int) -> None: - """ - Checkpoint the state of the recipe. The constructed checkpoint state dict - contains the following information: - - Merged weights with key MODEL_KEY - - Adapter weights with key ADAPTER_KEY - - Relevant recipe state if training is not complete - - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights - - To correctly resume from training, the adapter weights and recipe state must be provided along with the base model weights. - """ - ckpt_dict = {} - - intermediate_checkpoint = epoch + 1 < self.total_epochs - # if training is in-progress, checkpoint the optimizer state as well - if intermediate_checkpoint: - ckpt_dict.update( - { - training.OPT_KEY: self._optimizer.state_dict(), - training.SEED_KEY: self.seed, - training.EPOCHS_KEY: self.epochs_run, - training.TOTAL_EPOCHS_KEY: self.total_epochs, - training.MAX_STEPS_KEY: self.max_steps_per_epoch, - } - ) - - # Move to CPU to avoid a copy on GPU - state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()} - - # Construct the full state dict with LoRA weights merged into base LLM weights - merged_state_dict = get_merged_lora_ckpt( - state_dict, - rank=self._lora_rank, - alpha=self._lora_alpha, - ) - ckpt_dict.update({training.MODEL_KEY: merged_state_dict}) - - # Construct the adapter weights - adapter_key_filter = lambda x: x in self.adapter_params - adapter_state_dict = { - k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k) - } - ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict}) - adapter_config = { - "r": self._lora_rank, - "lora_alpha": self._lora_alpha, - "target_modules": get_lora_module_names( - self._lora_attn_modules, - self._apply_lora_to_mlp, - self._apply_lora_to_output, - ), - "peft_type": "LORA", - } - ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config}) - - self._checkpointer.save_checkpoint( - ckpt_dict, - epoch=epoch, - intermediate_checkpoint=intermediate_checkpoint, - adapter_only=self._save_adapter_weights_only, - ) - - def _loss_step( - self, batch: Dict[str, torch.Tensor] - ) -> (torch.Tensor, torch.Tensor): - - # Both are shape [b, s] - tokens, labels = batch["tokens"], batch["labels"] - - # Get the attention mask and position ids from the dataset if they - # exist. Currently, only sample packing in PackedDataset returns these - mask = batch.get("mask", None) # shape [b, s, s] - input_pos = batch.get("input_pos", None) # shape [b, s] - - # run model - logits = self._model(tokens, mask=mask, input_pos=input_pos) - - # Shift labels to compute loss - # equivalent to doing labels[..., 1:] and logits[..., :-1, :] - # But this way we dont need to slice the logits. We just add an ignore index to labels. - labels = torch.hstack( - (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) - ) - if not isinstance(logits, list): - labels = labels.reshape(-1) - logits = logits.reshape(-1, logits.size(-1)) - - # Compute KD loss - with torch.no_grad(): - teacher_logits = self._teacher_model(tokens, mask=mask, input_pos=input_pos) - - # Compute kd loss - kd_loss = self._kd_loss_fn(logits, teacher_logits, labels) - - # Compute loss - loss = self._loss_fn(logits, labels) - - # free logits otherwise it peaks backward memory - del logits - del teacher_logits - - return loss, kd_loss - - def train(self) -> None: - """ - The core training loop. - """ - - if self._compile: - log.info( - "NOTE: torch.compile is enabled and model is compiled in first forward. Expect a relatively slow first iteration." - ) - - # Initialize tokens count and running loss (for grad accumulation) - t0 = time.perf_counter() - running_class_loss = 0 - running_kd_loss = 0 - num_tokens = 0 - - with self._profiler as prof: - # self.epochs_run should be non-zero when we're resuming from a checkpoint - for curr_epoch in range(self.epochs_run, self.total_epochs): - # Update the sampler to ensure data is correctly shuffled across epochs - # in case shuffle is True - self._sampler.set_epoch(curr_epoch) - - pbar = tqdm(total=self._steps_per_epoch) - for idx, batch in enumerate(self._dataloader): - if ( - self.max_steps_per_epoch is not None - and (idx // self._gradient_accumulation_steps) - == self.max_steps_per_epoch - ): - break - - # Start tracking CUDA memory for active steps for just the first epoch - if ( - curr_epoch == 0 - and self.profiler_profile_memory - and idx == self.profiler_wait_steps + self.profiler_warmup_steps - ): - torch.cuda.memory._record_memory_history() - - batch = {k: v.to(self._device) for k, v in batch.items()} - num_tokens += batch["tokens"].numel() - - class_loss, kd_loss = self._loss_step(batch) - loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss - loss = loss / self._gradient_accumulation_steps - running_class_loss += class_loss / self._gradient_accumulation_steps - running_kd_loss += kd_loss / self._gradient_accumulation_steps - loss.backward() - - # Step with optimizer - if (idx + 1) % self._gradient_accumulation_steps == 0: - if self._clip_grad_norm is not None: - grad_norm = torch.nn.utils.clip_grad_norm_( - self._model.parameters(), - max_norm=float(self._clip_grad_norm), - ) - self._optimizer.step() - self._optimizer.zero_grad(set_to_none=True) - self._lr_scheduler.step() - # Update the number of steps when the weights are updated - self.global_step += 1 - - class_loss_to_log = running_class_loss.item() - kd_loss_to_log = running_kd_loss.item() - loss_to_log = ( - 1 - self._kd_ratio - ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log - pbar.update(1) - pbar.set_description( - f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}" - ) - - # Log per-step metrics - if self.global_step % self._log_every_n_steps == 0: - time_per_step = time.perf_counter() - t0 - log_dict = { - "loss": loss_to_log, - "class_loss": class_loss_to_log, - "kd_loss": kd_loss_to_log, - "lr": self._optimizer.param_groups[0]["lr"], - "tokens_per_second_per_gpu": num_tokens / time_per_step, - } - if ( - self._device.type == "cuda" - and self._log_peak_memory_stats - ): - log_dict.update( - training.get_memory_stats(device=self._device) - ) - if self._clip_grad_norm is not None: - log_dict.update({"grad_norm": grad_norm}) - self._metric_logger.log_dict( - log_dict, - step=self.global_step, - ) - - # Reset running stats for the next step - running_class_loss = 0 - running_kd_loss = 0 - num_tokens = 0 - t0 = time.perf_counter() - - # Stop tracking CUDA memory now that active steps are complete - if ( - curr_epoch == 0 - and self.profiler_profile_memory - and idx - == self.profiler_wait_steps - + self.profiler_warmup_steps - + self.profiler_active_steps - ): - torch.cuda.memory._record_memory_history(enabled=None) - - # Step the profiler - # Note we are stepping each batch, which might not include optimizer step in the trace - # if the schedule cycle doesn't align with gradient accumulation. - prof.step() - - self.epochs_run += 1 - self.save_checkpoint(epoch=curr_epoch) - - def cleanup(self) -> None: - self._metric_logger.close() - - -@config.parse -def recipe_main(cfg: DictConfig) -> None: - """ - Entry point for the recipe. - - Configurable parameters are read in the following order: - - Parameters specified in config (see available configs through ``tune ls``) - - Overwritten by arguments from the command-line - """ - config.log_config(recipe_name="KDRecipeSingleDevice", cfg=cfg) - recipe = KDRecipeSingleDevice(cfg=cfg) - recipe.setup(cfg=cfg) - recipe.train() - recipe.cleanup() - - -if __name__ == "__main__": - sys.exit(recipe_main()) diff --git a/recipes/kd_distributed.py b/recipes/knowledge_distillation_distributed.py similarity index 100% rename from recipes/kd_distributed.py rename to recipes/knowledge_distillation_distributed.py diff --git a/tests/recipes/test_kd_single_device.py b/tests/recipes/test_kd_single_device.py deleted file mode 100644 index b1ad8826b7..0000000000 --- a/tests/recipes/test_kd_single_device.py +++ /dev/null @@ -1,308 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import os -import runpy -import sys -from pathlib import Path - -import pytest -import torch -from omegaconf import OmegaConf -from tests.common import TUNE_PATH -from tests.recipes.utils import ( - CKPT_COMPONENT_MAP, - dummy_alpaca_dataset_config, - MODEL_TEST_CONFIGS, - write_hf_ckpt_config, -) -from tests.test_utils import ( - CKPT_MODEL_PATHS, - gen_log_file_name, - get_loss_values_from_metric_logger, - TOKENIZER_PATHS, -) -from torchtune import config - - -class TestKDSingleDeviceRecipe: - def _get_test_config_overrides(self, dtype_str: str = "fp32", epochs: int = 2): - return [ - "batch_size=8", - "device=cpu", - f"dtype={dtype_str}", - "enable_activation_checkpointing=False", - "dataset.train_on_input=False", - "seed=9", - f"epochs={epochs}", - "max_steps_per_epoch=2", - "optimizer.lr=2e-5", - "log_every_n_steps=1", - "gradient_accumulation_steps=1", - "clip_grad_norm=100", - ] + dummy_alpaca_dataset_config() - - def _fetch_expected_loss_values(self, model_type): - loss_values_map = { - "llama3": [11.0651, 11.0577, 11.0540, 11.7671], - } - return loss_values_map[model_type] - - @pytest.mark.integration_test - @pytest.mark.parametrize("compile", [True, False]) - @pytest.mark.parametrize( - "config, model_type, ckpt_type", - [ - ("qwen2/kd_single_device", "llama3", "tune"), - ], - ) - def test_loss(self, compile, config, model_type, ckpt_type, tmpdir, monkeypatch): - ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] - ckpt = model_type + "_" + ckpt_type - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - tokenizer_path = Path(TOKENIZER_PATHS[model_type]) - ckpt_dir = ckpt_path.parent - log_file = gen_log_file_name(tmpdir) - - cmd = f""" - tune run kd_single_device \ - --config {config} \ - output_dir={tmpdir} \ - checkpointer._component_={ckpt_component} \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}] \ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type={model_type.upper()} \ - teacher_checkpointer._component_={ckpt_component} \ - teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ - teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ - teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type={model_type.upper()} \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ - tokenizer.path='{tokenizer_path}' \ - tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ - metric_logger.filename={log_file} \ - compile={compile} \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ - """.split() - - model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] - teacher_config = [ - "teacher_" + config for config in MODEL_TEST_CONFIGS[model_type] - ] - - cmd = ( - cmd - + self._get_test_config_overrides(dtype_str="fp32") - + model_config - + teacher_config - ) - monkeypatch.setattr(sys, "argv", cmd) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - # Make sure to clear compile state in between tests - if compile: - torch._dynamo.reset() - - loss_values = get_loss_values_from_metric_logger(log_file) - # only take the first loss - num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs - loss_values = loss_values[0::num_losses] - expected_loss_values = self._fetch_expected_loss_values(model_type) - print(loss_values) - print(expected_loss_values) - torch.testing.assert_close( - loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 - ) - - @pytest.mark.integration_test - def test_training_state_on_resume(self, tmpdir, monkeypatch): - """Test whether the recipe state is correctly updated on resume. Since this - is model agnostic, we should run this on the small model only. The test - consists of three stages: - - Train a model for 2 epochs - - Resume training after epoch 1 - - Make sure final loss matches the expected value of a model successfully resumed from a ckpt - """ - - ckpt = "llama3_tune" - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - ckpt_dir = ckpt_path.parent - log_file = gen_log_file_name(tmpdir) - tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) - - # Config file needed for model conversion. - # Create a second copy for training resume - write_hf_ckpt_config(ckpt_dir) - write_hf_ckpt_config(tmpdir) - - # Train for two epochs - cmd_1 = f""" - tune run kd_single_device \ - --config qwen2/kd_single_device \ - output_dir={tmpdir} \ - checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}]\ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ - teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ - teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ - teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ - teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type=LLAMA3 \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ - tokenizer.path={tokenizer_path} \ - tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ - """.split() - - model_config = MODEL_TEST_CONFIGS["llama3_lora"] - teacher_config = [ - "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] - ] - - cmd_1 = ( - cmd_1 + self._get_test_config_overrides() + model_config + teacher_config - ) - monkeypatch.setattr(sys, "argv", cmd_1) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - # Resume training - cmd_2 = f""" - tune run kd_single_device \ - --config qwen2/kd_single_device \ - output_dir={tmpdir} \ - checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir={tmpdir} \ - checkpointer.checkpoint_files=[{ckpt_path}]\ - checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} - checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ - teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ - teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ - teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ - teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type=LLAMA3 \ - resume_from_checkpoint=True \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ - metric_logger.filename={log_file} \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ - tokenizer.path={tokenizer_path} \ - tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ - """.split() - cmd_2 = ( - cmd_2 - + self._get_test_config_overrides(epochs=3) - + model_config - + teacher_config - ) - monkeypatch.setattr(sys, "argv", cmd_2) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - # Second epoch only - expected_loss_values = self._fetch_expected_loss_values("llama3")[2:] - loss_values = get_loss_values_from_metric_logger(log_file) - # only take the first loss - num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs - loss_values = loss_values[0::num_losses][:2] - - torch.testing.assert_close( - loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 - ) - - @pytest.mark.integration_test - def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): - ckpt_type = "tune" - model_type = "llama3" - ckpt_component = CKPT_COMPONENT_MAP[ckpt_type] - ckpt = model_type + "_" + ckpt_type - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - tokenizer_path = Path(TOKENIZER_PATHS[model_type]) - ckpt_dir = ckpt_path.parent - log_file = gen_log_file_name(tmpdir) - - cmd = f""" - tune run kd_single_device \ - --config qwen2/kd_single_device \ - output_dir={tmpdir} \ - checkpointer._component_={ckpt_component} \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}] \ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type={model_type.upper()} \ - teacher_checkpointer._component_={ckpt_component} \ - teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ - teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ - teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type={model_type.upper()} \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ - tokenizer.path='{tokenizer_path}' \ - tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ - metric_logger.filename={log_file} \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ - """.split() - - model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] - teacher_config = [ - "teacher_" + config for config in MODEL_TEST_CONFIGS[model_type] - ] - - cmd = ( - cmd - + self._get_test_config_overrides(dtype_str="fp32") - + model_config - + teacher_config - ) - monkeypatch.setattr(sys, "argv", cmd) - with pytest.raises(SystemExit, match=""): - runpy.run_path(TUNE_PATH, run_name="__main__") - - # Next load both the merged weights in a Llama3 base model - # and the base model weights + trained adapter weights in the LoRA Llama 3 model - # The results of calling forward on dummy inputs should be the same. - inputs = torch.randint(low=0, high=32_000, size=(2, 100)) - - # Build LoRA model for loading base + adapter weights separately - lora_model = config.instantiate(OmegaConf.from_dotlist(model_config).model) - - # Build base llama3 model for loading merged weights - base_llama3_config = MODEL_TEST_CONFIGS[model_type] - llama3_model = config.instantiate( - OmegaConf.from_dotlist(base_llama3_config).model - ) - - # Load base model and trained adapter weights into LoRA model and call fwd - with open(f"{tmpdir}/adapter_1.pt", "rb") as f: - lora_sd = torch.load(f, weights_only=True) - with open(ckpt_path, "rb") as f: - base_model_sd = torch.load(f, weights_only=True) - lora_model.load_state_dict(lora_sd, strict=False) - lora_model.load_state_dict(base_model_sd, strict=False) - baseline_out = lora_model(inputs) - - # Load merged final ckpt directly into 3 and call fwd - with open(f"{tmpdir}/torchtune_model_1.pt", "rb") as f: - sd = torch.load(f, weights_only=True) - llama3_model.load_state_dict(sd) - merged_ckpt_out = llama3_model(inputs) - torch.testing.assert_close(baseline_out, merged_ckpt_out, rtol=1e-5, atol=1e-5) diff --git a/tests/recipes/test_kd_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py similarity index 95% rename from tests/recipes/test_kd_distributed.py rename to tests/recipes/test_knowledge_distillation_distributed.py index 5c05d30bda..d27d6c48d2 100644 --- a/tests/recipes/test_kd_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -68,8 +68,8 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) cmd = f""" - tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ - --config qwen2/kd_distributed \ + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config qwen2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -135,8 +135,8 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Train for two epochs cmd_1 = f""" - tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ - --config qwen2/kd_distributed \ + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config qwen2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ @@ -170,8 +170,8 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Resume training cmd_2 = f""" - tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ - --config qwen2/kd_distributed \ + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config qwen2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -227,8 +227,8 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): log_file = gen_log_file_name(tmpdir) cmd = f""" - tune run --nnodes 1 --nproc_per_node 2 kd_distributed \ - --config qwen2/kd_distributed \ + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config qwen2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index b9a1a3f983..84d0879b69 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -293,12 +293,12 @@ class Recipe: supports_distributed=False, ), Recipe( - name="kd_distributed", - file_path="kd_distributed.py", + name="knowledge_distillation_distributed", + file_path="knowledge_distillation_distributed.py", configs=[ Config( - name="qwen2/kd_distributed", - file_path="qwen2/kd_distributed.yaml", + name="qwen2/knowledge_distillation_distributed", + file_path="qwen2/knowledge_distillation_distributed.yaml", ), ], supports_distributed=True, From a04244d79a228866a5d51576ebcf2a65b17d9c7e Mon Sep 17 00:00:00 2001 From: lindawangg Date: Thu, 19 Sep 2024 19:20:37 -0700 Subject: [PATCH 29/37] cleaned up tests --- tests/recipes/test_knowledge_distillation_distributed.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py index d27d6c48d2..95e02bb812 100644 --- a/tests/recipes/test_knowledge_distillation_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -87,8 +87,6 @@ def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): ~tokenizer.merges_file \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ """.split() model_config = MODEL_TEST_CONFIGS["llama3_lora"] @@ -153,8 +151,6 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tokenizer.prompt_template=null \ ~tokenizer.merges_file \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ """.split() model_config = MODEL_TEST_CONFIGS["llama3_lora"] @@ -192,8 +188,6 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): tokenizer.path={tokenizer_path} \ tokenizer.prompt_template=null \ ~tokenizer.merges_file \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ """.split() cmd_2 = ( cmd_2 @@ -246,8 +240,6 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): ~tokenizer.merges_file \ metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ - kd_loss._component_=torchtune.modules.loss.ForwardKLWithChunkedOutputLoss \ - kd_ratio=0.5 \ """.split() model_config = MODEL_TEST_CONFIGS[model_type + "_lora"] From 1ff9934db136f91c4c06f5de8760ba638e6c7889 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Fri, 20 Sep 2024 11:03:06 -0700 Subject: [PATCH 30/37] added gpu test --- tests/recipes/test_knowledge_distillation_distributed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py index 95e02bb812..a603528dcb 100644 --- a/tests/recipes/test_knowledge_distillation_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -210,6 +210,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): ) @pytest.mark.integration_test + @gpu_test(gpu_count=2) def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): ckpt_type = "tune" model_type = "llama3" From 307791df0ce6579e5e052cd8138ce986bd48c339 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Mon, 14 Oct 2024 21:39:28 -0700 Subject: [PATCH 31/37] added llama3 config and addressed comments --- .../knowledge_distillation_distributed.yaml | 130 ++++++++++++++++++ .../knowledge_distillation_distributed.yaml | 6 +- recipes/knowledge_distillation_distributed.py | 32 +++-- ...test_knowledge_distillation_distributed.py | 79 +---------- torchtune/_recipe_registry.py | 4 + 5 files changed, 158 insertions(+), 93 deletions(-) create mode 100644 recipes/configs/llama3_2/knowledge_distillation_distributed.yaml diff --git a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml new file mode 100644 index 0000000000..f2259552bb --- /dev/null +++ b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml @@ -0,0 +1,130 @@ +# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py +# using a teacher and student model +# +# This config assumes that you've ran the following commands before launching KD: +# First download the student and teacher models +# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth" +# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth" +# +# You get better results using KD if the teacher model has already been fine-tuned on the target dataset: +# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora +# +# To launch on a 2 devices, run the following command from root: +# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed +# +# This config works best for distilling on 2+ devices. + + +# Model Arguments +model: + _component_: torchtune.models.llama3_2.lora_llama3_2_1b + lora_attn_modules: ['q_proj', 'v_proj', 'output_proj'] + apply_lora_to_mlp: True + apply_lora_to_output: False + lora_rank: 64 + lora_alpha: 128 + lora_dropout: 0.0 + +teacher_model: + _component_: torchtune.models.llama3_1.llama3_1_8b + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model + max_seq_len: null + +checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/ + checkpoint_files: [ + model.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Llama-3.2-1B-Instruct/ + model_type: LLAMA3 +resume_from_checkpoint: False +save_adapter_weights_only: False + +# Teacher checkpoint +teacher_checkpointer: + _component_: torchtune.training.FullModelHFCheckpointer + checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors + ] + recipe_checkpoint: null + output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/ + model_type: LLAMA3 + +# Dataset and Sampler +dataset: + _component_: torchtune.datasets.alpaca_cleaned_dataset +seed: null +shuffle: True +batch_size: 4 + +# Optimizer and Scheduler +optimizer: + _component_: torch.optim.AdamW + fused: True + weight_decay: 0.01 + lr: 3e-4 +lr_scheduler: + _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup + num_warmup_steps: 100 + +loss: + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss + +kd_loss: + _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss +kd_ratio: 0.5 + +# Training +epochs: 1 +max_steps_per_epoch: null +gradient_accumulation_steps: 32 + +# Logging +output_dir: /tmp/kd_output +metric_logger: + _component_: torchtune.training.metric_logging.DiskLogger + log_dir: ${output_dir} +log_every_n_steps: 1 +log_peak_memory_stats: False + +# Environment +device: cuda +dtype: bf16 +enable_activation_checkpointing: False + +# Show case the usage of pytorch profiler +# Set enabled to False as it's only needed for debugging training +profiler: + _component_: torchtune.training.setup_torch_profiler + + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: False + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 5 + warmup_steps: 5 + active_steps: 2 + num_cycles: 1 diff --git a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml index 33002af76b..6c922bbc79 100644 --- a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml +++ b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml @@ -9,10 +9,10 @@ # You get better results using KD if the teacher model has already been fine-tuned on the target dataset: # tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora # -# To launch on a single device, run the following command from root: +# To launch on a 2 devices, run the following command from root: # tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed # -# This config works only for distilling on a single device. +# This config works best for distilling on 2+ devices. # Model Arguments @@ -93,7 +93,7 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 -enable_activation_checkpointing: True +enable_activation_checkpointing: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index cd8291cae4..d702c03b2e 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -46,6 +46,12 @@ class KDRecipeDistributed(FTRecipeInterface): for single GPU training. Training on CPU is not supported. Features: + - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states + is supported via ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is + done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config + ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). + DDP is currently not supported. Training on CPU is not supported. + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially @@ -58,7 +64,7 @@ class KDRecipeDistributed(FTRecipeInterface): most cases this should halve the memory footprint of full precision (fp32) training, without loss in model quality (will depend on the model, training data and other settings). For GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 - precision are currently not supported.g + precision are currently not supported. - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is controlled using the ``gradient_accumulation_steps`` flag. @@ -71,10 +77,6 @@ class KDRecipeDistributed(FTRecipeInterface): accumulating gradients might give you better training speed than enabling activation checkpointing. - - Lower precision optimizers. This recipe supports lower-precision optimizers from the bitsandbytes - library (https://huggingface.co/docs/bitsandbytes/main/en/index). We've tested the recipe with - 8-bit AdamW and Paged AdamW. - - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of training. Currently we checkpoint both the adapter weights (trainable params only) and the complete merged weights (adapter weights added back to the base model). For more details @@ -503,7 +505,7 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: is_dora = False for m in model.modules(): if hasattr(m, "initialize_dora_magnitude"): - is_dora = (True,) + is_dora = True m.initialize_dora_magnitude() if is_dora: load_dora_magnitudes(model) @@ -687,14 +689,16 @@ def _setup_data( dataset=ds, batch_size=batch_size, sampler=sampler, - collate_fn=partial( - padded_collate_sft, - padding_idx=self._tokenizer.pad_id, - ignore_idx=self._loss_fn.ignore_index, - ) - if not packed - else partial( - padded_collate_packed, + collate_fn=( + partial( + padded_collate_sft, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else partial( + padded_collate_packed, + ) ), ) diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py index a603528dcb..50fcf6555e 100644 --- a/tests/recipes/test_knowledge_distillation_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -51,64 +51,6 @@ def _fetch_expected_loss_values(self, model_type): } return loss_values_map[model_type] - @pytest.mark.integration_test - @gpu_test(gpu_count=2) - @pytest.mark.parametrize( - "reshard_after_forward", - [ - True, - False, - ], - ) - def test_loss(self, reshard_after_forward, tmpdir, monkeypatch): - ckpt = "llama3_tune" - ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) - ckpt_dir = ckpt_path.parent - log_file = gen_log_file_name(tmpdir) - tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) - - cmd = f""" - tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ - --config qwen2/knowledge_distillation_distributed \ - output_dir={tmpdir} \ - checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ - checkpointer.checkpoint_dir='{ckpt_dir}' \ - checkpointer.checkpoint_files=[{ckpt_path}] \ - checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ - teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ - teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ - teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ - teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type=LLAMA3 \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ - tokenizer.path='{tokenizer_path}' \ - tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ - metric_logger.filename={log_file} \ - """.split() - - model_config = MODEL_TEST_CONFIGS["llama3_lora"] - teacher_config = [ - "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] - ] - - cmd = cmd + self._get_test_config_overrides() + model_config + teacher_config - monkeypatch.setattr(sys, "argv", cmd) - runpy.run_path(TUNE_PATH, run_name="__main__") - - loss_values = get_loss_values_from_metric_logger(log_file) - # only take the first loss - num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs - loss_values = loss_values[0::num_losses] - expected_loss_values = self._fetch_expected_loss_values("llama3") - print(loss_values) - print(expected_loss_values) - torch.testing.assert_close( - loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 - ) - @pytest.mark.integration_test @gpu_test(gpu_count=2) def test_training_state_on_resume(self, tmpdir, monkeypatch): @@ -134,23 +76,18 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Train for two epochs cmd_1 = f""" tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ - --config qwen2/knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}]\ checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type=LLAMA3 \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path={tokenizer_path} \ tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ """.split() model_config = MODEL_TEST_CONFIGS["llama3_lora"] @@ -167,7 +104,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): # Resume training cmd_2 = f""" tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ - --config qwen2/knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer=torchtune.training.FullModelTorchTuneCheckpointer \ checkpointer.checkpoint_dir={tmpdir} \ @@ -175,19 +112,14 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch): checkpointer.adapter_checkpoint={os.path.join(tmpdir, "adapter_0.pt")} checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")} checkpointer.output_dir={tmpdir} \ - checkpointer.model_type=LLAMA3 \ teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type=LLAMA3 \ resume_from_checkpoint=True \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path={tokenizer_path} \ tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ """.split() cmd_2 = ( cmd_2 @@ -223,23 +155,18 @@ def test_save_and_load_merged_weights(self, tmpdir, monkeypatch): cmd = f""" tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ - --config qwen2/knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ output_dir={tmpdir} \ checkpointer._component_={ckpt_component} \ checkpointer.checkpoint_dir='{ckpt_dir}' \ checkpointer.checkpoint_files=[{ckpt_path}] \ checkpointer.output_dir={tmpdir} \ - checkpointer.model_type={model_type.upper()} \ teacher_checkpointer._component_={ckpt_component} \ teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ teacher_checkpointer.output_dir={tmpdir} \ - teacher_checkpointer.model_type={model_type.upper()} \ - tokenizer._component_=torchtune.models.llama3.llama3_tokenizer \ tokenizer.path='{tokenizer_path}' \ tokenizer.prompt_template=null \ - ~tokenizer.merges_file \ - metric_logger._component_=torchtune.training.metric_logging.DiskLogger \ metric_logger.filename={log_file} \ """.split() diff --git a/torchtune/_recipe_registry.py b/torchtune/_recipe_registry.py index ea4b56bc6d..ef5440c3ca 100644 --- a/torchtune/_recipe_registry.py +++ b/torchtune/_recipe_registry.py @@ -383,6 +383,10 @@ class Recipe: name="qwen2/knowledge_distillation_distributed", file_path="qwen2/knowledge_distillation_distributed.yaml", ), + Config( + name="llama3_2/knowledge_distillation_distributed", + file_path="llama3_2/knowledge_distillation_distributed.yaml", + ), ], supports_distributed=True, ), From fefc24d7988f7a37c8f10c342e401c28bb94b713 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 15 Oct 2024 16:30:31 -0700 Subject: [PATCH 32/37] added custom sharding layers --- recipes/knowledge_distillation_distributed.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index d702c03b2e..664a6c372f 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -8,7 +8,7 @@ import time from functools import partial -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn import torch @@ -248,6 +248,7 @@ def setup(self, cfg: DictConfig) -> None: self._teacher_model = self._setup_teacher_model( model_cfg=cfg.teacher_model, + custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), model_state_dict=teacher_checkpoint_dict[training.MODEL_KEY], @@ -536,6 +537,7 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: def _setup_teacher_model( self, model_cfg: DictConfig, + custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, reshard_after_forward: bool, model_state_dict: Dict[str, Any], @@ -557,6 +559,9 @@ def _setup_teacher_model( with training.set_default_dtype(self._dtype), torch.device("meta"): model = config.instantiate(model_cfg) + if self._compile: + training.compile_model(model, verbose=self._is_rank_zero) + # For FSDP sharding, we can condition on either the module or its name # Shard conditions should be callables taking name (relative to model root) # and the module itself and returning a bool on whether to shard the given module @@ -575,6 +580,9 @@ def _is_layer_fqn(s: str) -> bool: fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] + if custom_sharded_layers: + fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers] + training.shard_model( model=model, shard_conditions=fsdp_shard_conditions, From 46473ee56a8878f3203c9a2fd4bf36a9865042b3 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Tue, 22 Oct 2024 10:56:30 -0700 Subject: [PATCH 33/37] add test_loss back in --- ...test_knowledge_distillation_distributed.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py index 50fcf6555e..9e1ba08509 100644 --- a/tests/recipes/test_knowledge_distillation_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -51,6 +51,52 @@ def _fetch_expected_loss_values(self, model_type): } return loss_values_map[model_type] + @pytest.mark.integration_test + @gpu_test(gpu_count=2) + def test_loss(self, tmpdir, monkeypatch): + ckpt = "llama3_tune" + ckpt_path = Path(CKPT_MODEL_PATHS[ckpt]) + ckpt_dir = ckpt_path.parent + log_file = gen_log_file_name(tmpdir) + tokenizer_path = Path(TOKENIZER_PATHS["llama3"]) + + cmd = f""" + tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed \ + --config llama3_2/knowledge_distillation_distributed \ + output_dir={tmpdir} \ + checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + checkpointer.checkpoint_dir='{ckpt_dir}' \ + checkpointer.checkpoint_files=[{ckpt_path}] \ + checkpointer.output_dir={tmpdir} \ + teacher_checkpointer._component_=torchtune.training.FullModelTorchTuneCheckpointer \ + teacher_checkpointer.checkpoint_dir='{ckpt_dir}' \ + teacher_checkpointer.checkpoint_files=[{ckpt_path}] \ + teacher_checkpointer.output_dir={tmpdir} \ + tokenizer.path='{tokenizer_path}' \ + tokenizer.prompt_template=null \ + metric_logger.filename={log_file} \ + """.split() + + model_config = MODEL_TEST_CONFIGS["llama3_lora"] + teacher_config = [ + "teacher_" + config for config in MODEL_TEST_CONFIGS["llama3"] + ] + + cmd = cmd + self._get_test_config_overrides() + model_config + teacher_config + monkeypatch.setattr(sys, "argv", cmd) + runpy.run_path(TUNE_PATH, run_name="__main__") + + loss_values = get_loss_values_from_metric_logger(log_file) + # only take the first loss + num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs + loss_values = loss_values[0::num_losses] + expected_loss_values = self._fetch_expected_loss_values("llama3") + print(loss_values) + print(expected_loss_values) + torch.testing.assert_close( + loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 + ) + @pytest.mark.integration_test @gpu_test(gpu_count=2) def test_training_state_on_resume(self, tmpdir, monkeypatch): From 557396ec014d75259e17706ff5d0785569d658ce Mon Sep 17 00:00:00 2001 From: lindawangg Date: Wed, 23 Oct 2024 20:41:28 -0700 Subject: [PATCH 34/37] rebase --- tests/recipes/test_knowledge_distillation_distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py index 9e1ba08509..ffffb47a3c 100644 --- a/tests/recipes/test_knowledge_distillation_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -47,7 +47,7 @@ def _get_test_config_overrides(self, epochs: int = 2): def _fetch_expected_loss_values(self, model_type): loss_values_map = { - "llama3": [10.3821, 10.3025, 11.0394, 11.7664], + "llama3": [11.8316, 11.7520, 11.7642, 11.7664], } return loss_values_map[model_type] From 53c47ba9047c86fe079dd54dfa9fe27ad6f085d7 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Fri, 25 Oct 2024 15:39:33 -0700 Subject: [PATCH 35/37] grad accumulation changes --- recipes/knowledge_distillation_distributed.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index 664a6c372f..ff6b51a620 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -885,23 +885,31 @@ def train(self) -> None: batch = {k: v.to(self._device) for k, v in batch.items()} num_tokens += batch["tokens"].numel() + # Calculate the number of unmasked tokens in the current batch + # and increment the total number of tokens seen in the step + current_num_tokens = ( + batch["labels"] != self._loss_fn.ignore_index + ).sum() + num_tokens += current_num_tokens + class_loss, kd_loss = self._loss_step(batch) - loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss - loss = loss / self._gradient_accumulation_steps - running_class_loss += class_loss / self._gradient_accumulation_steps - running_kd_loss += kd_loss / self._gradient_accumulation_steps - loss.backward() + running_class_loss += class_loss * current_num_tokens + running_kd_loss += kd_loss * current_num_tokens # Step with optimizer if (idx + 1) % self._gradient_accumulation_steps == 0: + class_loss = running_class_loss / num_tokens + kd_loss = running_kd_loss / num_tokens + loss = (1 - self._kd_ratio) * class_loss + self._kd_ratio * kd_loss + loss.backward() self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self._lr_scheduler.step() # Update the number of steps when the weights are updated self.global_step += 1 - class_loss_to_log = running_class_loss.item() - kd_loss_to_log = running_kd_loss.item() + class_loss_to_log = class_loss.item() + kd_loss_to_log = kd_loss.item() loss_to_log = ( 1 - self._kd_ratio ) * class_loss_to_log + self._kd_ratio * kd_loss_to_log From f193d028b0c83976c2cca30a070ac406354e8653 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Fri, 25 Oct 2024 17:20:52 -0700 Subject: [PATCH 36/37] remove extra num_tokens --- recipes/knowledge_distillation_distributed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index ff6b51a620..87342b9def 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -883,7 +883,6 @@ def train(self) -> None: torch.cuda.memory._record_memory_history() batch = {k: v.to(self._device) for k, v in batch.items()} - num_tokens += batch["tokens"].numel() # Calculate the number of unmasked tokens in the current batch # and increment the total number of tokens seen in the step From 227e69d9a183f3e1eade2aa11836c882e2a8f017 Mon Sep 17 00:00:00 2001 From: lindawangg Date: Mon, 28 Oct 2024 12:55:07 -0700 Subject: [PATCH 37/37] addressed comments --- .../knowledge_distillation_distributed.yaml | 2 +- .../knowledge_distillation_distributed.yaml | 2 +- recipes/knowledge_distillation_distributed.py | 61 ++++++------------- ...test_knowledge_distillation_distributed.py | 4 +- 4 files changed, 22 insertions(+), 47 deletions(-) diff --git a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml index f2259552bb..1cc864b900 100644 --- a/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml +++ b/recipes/configs/llama3_2/knowledge_distillation_distributed.yaml @@ -9,7 +9,7 @@ # You get better results using KD if the teacher model has already been fine-tuned on the target dataset: # tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora # -# To launch on a 2 devices, run the following command from root: +# To launch on 2 devices, run the following command from root: # tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/knowledge_distillation_distributed # # This config works best for distilling on 2+ devices. diff --git a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml index 6c922bbc79..9727860ca7 100644 --- a/recipes/configs/qwen2/knowledge_distillation_distributed.yaml +++ b/recipes/configs/qwen2/knowledge_distillation_distributed.yaml @@ -9,7 +9,7 @@ # You get better results using KD if the teacher model has already been fine-tuned on the target dataset: # tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/1.5B_lora # -# To launch on a 2 devices, run the following command from root: +# To launch on 2 devices, run the following command from root: # tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config qwen2/knowledge_distillation_distributed # # This config works best for distilling on 2+ devices. diff --git a/recipes/knowledge_distillation_distributed.py b/recipes/knowledge_distillation_distributed.py index 87342b9def..d17e480ba6 100644 --- a/recipes/knowledge_distillation_distributed.py +++ b/recipes/knowledge_distillation_distributed.py @@ -323,7 +323,7 @@ def setup(self, cfg: DictConfig) -> None: ) # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method) - # if cfg is missing profiler key or if `cfg.profiler.enabled = False + # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) # Used to ignore labels for loss computation @@ -406,6 +406,7 @@ def _setup_model( fsdp_cpu_offload: bool, reshard_after_forward: bool, base_model_state_dict: Dict[str, Any], + custom_sharded_layers: Optional[List[str]] = None, lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: """ @@ -443,28 +444,16 @@ def _setup_model( model, auto_wrap_policy={modules.TransformerSelfAttentionLayer} ) - # For FSDP sharding, we can condition on either the module or its name - # Shard conditions should be callables taking name (relative to model root) - # and the module itself and returning a bool on whether to shard the given module - - # Shard transformer decoder layers (or AC-wrapped versions) - # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) - # But directly using the name is more concise - def _is_layer_name(name: str, module: nn.Module) -> bool: - """ - Return True for layers.i and False for all other module names - Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot - """ - name_list = name.split(".") - return ( - len(name_list) == 2 - and name_list[0] == "layers" - and str.isdigit(name_list[1]) + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, ) - + ] training.shard_model( model=model, - shard_conditions=[_is_layer_name], + shard_conditions=fsdp_shard_conditions, cpu_offload=fsdp_cpu_offload, reshard_after_forward=reshard_after_forward, ) @@ -562,27 +551,13 @@ def _setup_teacher_model( if self._compile: training.compile_model(model, verbose=self._is_rank_zero) - # For FSDP sharding, we can condition on either the module or its name - # Shard conditions should be callables taking name (relative to model root) - # and the module itself and returning a bool on whether to shard the given module - fsdp_shard_conditions = [] - - # Shard transformer decoder layers (or AC-wrapped versions) - # Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper) - # But directly using the name is more concise - def _is_layer_fqn(s: str) -> bool: - """ - Return True for layers.i and False for all other module names - Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot - """ - s_list = s.split(".") - return len(s_list) == 2 and s_list[0] == "layers" and str.isdigit(s_list[1]) - - fsdp_shard_conditions = [lambda n, m: _is_layer_fqn(n)] - - if custom_sharded_layers: - fsdp_shard_conditions += [lambda n, m: n in custom_sharded_layers] - + # For FSDP sharding + fsdp_shard_conditions = [ + partial( + training.get_shard_conditions, + names_to_match=custom_sharded_layers, + ) + ] training.shard_model( model=model, shard_conditions=fsdp_shard_conditions, @@ -864,7 +839,7 @@ def train(self) -> None: # in case shuffle is True self._sampler.set_epoch(curr_epoch) - pbar = tqdm(total=self._steps_per_epoch) + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) for idx, batch in enumerate(self._dataloader): if ( self.max_steps_per_epoch is not None @@ -965,6 +940,8 @@ def train(self) -> None: self.epochs_run += 1 self.save_checkpoint(epoch=curr_epoch) + self._profiler.stop() + def cleanup(self) -> None: if self._is_rank_zero: self._metric_logger.close() diff --git a/tests/recipes/test_knowledge_distillation_distributed.py b/tests/recipes/test_knowledge_distillation_distributed.py index ffffb47a3c..949883ac48 100644 --- a/tests/recipes/test_knowledge_distillation_distributed.py +++ b/tests/recipes/test_knowledge_distillation_distributed.py @@ -29,7 +29,7 @@ from torchtune import config -class TestKDDistributedDeviceRecipe: +class TestKDDistributedRecipe: def _get_test_config_overrides(self, epochs: int = 2): return [ "batch_size=4", @@ -91,8 +91,6 @@ def test_loss(self, tmpdir, monkeypatch): num_losses = int(len(loss_values) / 4) # 2 steps per epoch, 2 epochs loss_values = loss_values[0::num_losses] expected_loss_values = self._fetch_expected_loss_values("llama3") - print(loss_values) - print(expected_loss_values) torch.testing.assert_close( loss_values, expected_loss_values, rtol=1e-5, atol=1e-5 )